diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..0cd58331b2a989b68be4ec5676383437fca8687b
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..8a0162bfc076e35c9b4d87579f05f86ff2639a43
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+.venv
+__pycache__
+.bak
+megablocks-moe/.bak
+.pytest_cache
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2bed67f0bc24fc62546154442dad44b08d71d39c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,76 @@
+---
+license: apache-2.0
+tags:
+ - kernels
+---
+
+## Quickstart
+
+```bash
+uv run https://huggingface.co/kernels-community/megablocks/raw/main/readme_example.py
+```
+
+```python
+# /// script
+# requires-python = "==3.10"
+# dependencies = [
+# "numpy",
+# "kernels",
+# "torch"
+# ]
+# ///
+
+import torch
+from collections import namedtuple
+
+from kernels import get_kernel
+
+# Make reproducible
+torch.manual_seed(42)
+torch.cuda.manual_seed(42)
+
+# Download optimized kernels from the Hugging Face hub
+megablocks = get_kernel("kernels-community/megablocks")
+print("MegaBlocks kernel downloaded successfully.")
+
+model = megablocks.layers.MegaBlocksMoeMLP()
+model.experts = namedtuple("Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"])
+print("MegaBlocksMoeMLP instance created successfully.")
+
+# Config
+ne, hs, isz = 128, 1152, 3072
+
+# Router with proper initialization
+model.router = torch.nn.Linear(hs, ne, device="cuda")
+torch.nn.init.kaiming_uniform_(model.router.weight)
+
+# Expert layers with realistic weights
+e = model.experts
+e.gate_up_proj = torch.nn.Parameter(torch.randn(ne, hs, isz, device="cuda") * 0.02)
+e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device="cuda"))
+e.down_proj = torch.nn.Parameter(torch.randn(ne, 1536, hs, device="cuda") * 0.02)  # 1536 = isz // 2 (gate/up interleave halves the width)
+e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device="cuda"))
+e.hidden_size = hs
+print("Expert layers initialized successfully.")
+
+# Test with normalized input
+x = torch.randn(1, 1, hs, device="cuda") * 0.1
+output, expert_weights = model(x)
+print("Model forward pass completed successfully.")
+
+print(f"Output shape: {output.shape}")
+print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
+print(f"Output: {output.flatten()[:10]}")
+print(f"Expert weights sum: {expert_weights.sum():.3f}")
+```
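+
+The benchmark script in `benchmarks/benchmark.py` includes a pure-PyTorch reference
+(`moe_mlp_reference`) that can be used to sanity-check the kernel output. A minimal
+sketch, assuming you copy that function into your script and that the layer routes
+with top-k = 4:
+
+```python
+ref_out, _ = moe_mlp_reference(
+    x,
+    model.router.weight,
+    model.router.bias,
+    e.gate_up_proj,
+    e.gate_up_proj_bias,
+    e.down_proj,
+    e.down_proj_bias,
+    top_k=4,  # assumption: matches the layer's routing top-k
+)
+torch.testing.assert_close(output, ref_out, rtol=1e-2, atol=1e-2)
+```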
+
+### Performance
+
+
+
+
+
+
+
+
+
diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..780d585a4068ff525a8cd071e64cdc1a4d39c988
--- /dev/null
+++ b/benchmarks/benchmark.py
@@ -0,0 +1,233 @@
+import torch
+import torch.nn.functional as F
+from collections import namedtuple
+
+from kernels.benchmark import Benchmark
+
+
+def moe_mlp_reference(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ gate_up_proj: torch.Tensor,
+ gate_up_proj_bias: torch.Tensor,
+ down_proj: torch.Tensor,
+ down_proj_bias: torch.Tensor,
+ top_k: int = 4,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ in_shape = x.shape
+ num_experts = router_weight.shape[0]
+ hidden_size = x.shape[-1]
+
+ # Flatten to (num_tokens, hidden_size)
+ hidden_states = x.view(-1, hidden_size)
+ num_tokens = hidden_states.shape[0]
+
+ # Router: compute logits and get top-k experts
+ logits = F.linear(hidden_states, router_weight, router_bias)
+ expert_weights, router_indices = torch.topk(logits, top_k, dim=-1)
+ routing_weights = F.softmax(expert_weights, dim=-1)
+
+ # Initialize output
+ next_states = torch.zeros_like(hidden_states)
+
+ # Create expert mask using one_hot
+ with torch.no_grad():
+ expert_mask = F.one_hot(router_indices, num_classes=num_experts)
+ expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, num_tokens)
+ # Find which experts are hit
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+ # Process each expert that has tokens
+ for expert_idx in expert_hit:
+ expert_idx = expert_idx[0]
+ with torch.no_grad():
+ top_k_idx, token_idx = torch.where(expert_mask[expert_idx])
+
+ current_state = hidden_states[token_idx]
+
+ # Up projection
+ gate_up = (
+ current_state @ gate_up_proj[expert_idx] + gate_up_proj_bias[expert_idx]
+ )
+
+ # Split into gate and up
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+
+ # Clamp
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+
+ # SwiGLU-like activation
+ glu = gate * torch.sigmoid(gate * alpha)
+ gated_output = (up + 1) * glu
+
+ # Down projection
+ out = gated_output @ down_proj[expert_idx] + down_proj_bias[expert_idx]
+
+ # Get the routing weight for this expert at the correct top_k position
+ weights_for_expert = routing_weights[token_idx, top_k_idx]
+ weighted_output = out * weights_for_expert[:, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+
+ return next_states.view(in_shape), routing_weights
+
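+# Shape note: for x of shape [seq, batch, hidden_size] the reference returns an output
+# of the same shape and routing weights of shape [num_tokens, top_k], where
+# num_tokens = seq * batch.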
+
+class MegaBlocksMoeBenchmark(Benchmark):
+ seed: int = 42
+
+ def setup(self):
+ # Config matching readme_example.py
+ ne, hs, isz = 128, 1152, 3072
+ batch, seq = 8, 1
+
+ # Router
+ self.router_weight = torch.randn(
+ ne, hs, device=self.device, dtype=torch.float32
+ )
+ torch.nn.init.kaiming_uniform_(self.router_weight)
+ self.router_bias = torch.zeros(ne, device=self.device, dtype=torch.float32)
+
+ # Expert weights
+ self.gate_up_proj = (
+ torch.randn(ne, hs, isz, device=self.device, dtype=torch.float32) * 0.02
+ )
+ self.gate_up_proj_bias = torch.zeros(
+ ne, isz, device=self.device, dtype=torch.float32
+ )
+ self.down_proj = (
+ torch.randn(ne, isz // 2, hs, device=self.device, dtype=torch.float32)
+ * 0.02
+ )
+ self.down_proj_bias = torch.zeros(
+ ne, hs, device=self.device, dtype=torch.float32
+ )
+
+ # Input
+ self.x = (
+ torch.randn(seq, batch, hs, device=self.device, dtype=torch.float32) * 0.1
+ )
+
+ # Setup the model
+ self.model = self.kernel.layers.MegaBlocksMoeMLP()
+ self.model.router = torch.nn.Linear(hs, ne, device=self.device)
+ self.model.router.weight.data = self.router_weight.clone()
+ self.model.router.bias.data = self.router_bias.clone()
+
+ Experts = namedtuple(
+ "Experts",
+ [
+ "gate_up_proj",
+ "gate_up_proj_bias",
+ "down_proj",
+ "down_proj_bias",
+ "hidden_size",
+ "num_experts",
+ ],
+ )
+ self.model.experts = Experts(
+ gate_up_proj=torch.nn.Parameter(self.gate_up_proj.clone()),
+ gate_up_proj_bias=torch.nn.Parameter(self.gate_up_proj_bias.clone()),
+ down_proj=torch.nn.Parameter(self.down_proj.clone()),
+ down_proj_bias=torch.nn.Parameter(self.down_proj_bias.clone()),
+ hidden_size=hs,
+ num_experts=ne,
+ )
+
+ self.out = torch.empty(seq, batch, hs, device=self.device, dtype=torch.float32)
+
+ def benchmark_base(self):
+ self.out, self.expert_weights = self.model(self.x)
+
+ def verify_base(self) -> torch.Tensor:
+ ref_out, _ = moe_mlp_reference(
+ self.x,
+ self.router_weight,
+ self.router_bias,
+ self.gate_up_proj,
+ self.gate_up_proj_bias,
+ self.down_proj,
+ self.down_proj_bias,
+ top_k=4,
+ )
+ return ref_out
+
+ def setup_large(self):
+ # Larger config with more tokens
+ ne, hs, isz = 128, 1152, 3072
+ batch, seq = 32, 16
+
+ # Router
+ self.router_weight = torch.randn(
+ ne, hs, device=self.device, dtype=torch.float32
+ )
+ torch.nn.init.kaiming_uniform_(self.router_weight)
+ self.router_bias = torch.zeros(ne, device=self.device, dtype=torch.float32)
+
+ # Expert weights
+ self.gate_up_proj = (
+ torch.randn(ne, hs, isz, device=self.device, dtype=torch.float32) * 0.02
+ )
+ self.gate_up_proj_bias = torch.zeros(
+ ne, isz, device=self.device, dtype=torch.float32
+ )
+ self.down_proj = (
+ torch.randn(ne, isz // 2, hs, device=self.device, dtype=torch.float32)
+ * 0.02
+ )
+ self.down_proj_bias = torch.zeros(
+ ne, hs, device=self.device, dtype=torch.float32
+ )
+
+ # Input
+ self.x = (
+ torch.randn(seq, batch, hs, device=self.device, dtype=torch.float32) * 0.1
+ )
+
+ # Setup the model
+ self.model = self.kernel.layers.MegaBlocksMoeMLP()
+ self.model.router = torch.nn.Linear(hs, ne, device=self.device)
+ self.model.router.weight.data = self.router_weight.clone()
+ self.model.router.bias.data = self.router_bias.clone()
+
+ Experts = namedtuple(
+ "Experts",
+ [
+ "gate_up_proj",
+ "gate_up_proj_bias",
+ "down_proj",
+ "down_proj_bias",
+ "hidden_size",
+ "num_experts",
+ "capacity_factor",
+ ],
+ )
+ self.model.experts = Experts(
+ gate_up_proj=torch.nn.Parameter(self.gate_up_proj.clone()),
+ gate_up_proj_bias=torch.nn.Parameter(self.gate_up_proj_bias.clone()),
+ down_proj=torch.nn.Parameter(self.down_proj.clone()),
+ down_proj_bias=torch.nn.Parameter(self.down_proj_bias.clone()),
+ hidden_size=hs,
+ num_experts=ne,
+ capacity_factor=4.0, # Higher capacity to avoid token dropping
+ )
+
+ self.out = torch.empty(seq, batch, hs, device=self.device, dtype=torch.float32)
+
+ def benchmark_large(self):
+ self.out, self.expert_weights = self.model(self.x)
+
+ def verify_large(self) -> torch.Tensor:
+ ref_out, _ = moe_mlp_reference(
+ self.x,
+ self.router_weight,
+ self.router_bias,
+ self.gate_up_proj,
+ self.gate_up_proj_bias,
+ self.down_proj,
+ self.down_proj_bias,
+ top_k=4,
+ )
+ return ref_out
diff --git a/build.toml b/build.toml
new file mode 100644
index 0000000000000000000000000000000000000000..863ecb6568219c7bb42b06cc0a81445a7639fc8a
--- /dev/null
+++ b/build.toml
@@ -0,0 +1,43 @@
+[general]
+name = "megablocks"
+universal = false
+
+[torch]
+src = [
+ "torch-ext/torch_binding.cpp",
+ "torch-ext/torch_binding.h"
+]
+
+[kernel.megablocks]
+backend = "cuda"
+cuda-capabilities = [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0",
+ "10.0",
+ "10.1",
+ "11.8",
+ "12.0"
+]
+depends = ["torch", "cutlass_3_8"]
+src = [
+ "csrc/new_cumsum.h",
+ "csrc/new_cumsum.cu",
+ "csrc/new_histogram.h",
+ "csrc/new_histogram.cu",
+ "csrc/new_indices.h",
+ "csrc/new_indices.cu",
+ "csrc/new_replicate.cu",
+ "csrc/new_replicate.h",
+ "csrc/new_sort.h",
+ "csrc/new_sort.cu",
+ # vendored grouped gemm
+ "csrc/grouped_gemm/fill_arguments.cuh",
+ "csrc/grouped_gemm/grouped_gemm.cu",
+ "csrc/grouped_gemm/grouped_gemm.h",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/__init__.py b/build/torch210-cxx11-cpu-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
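+
+# Usage sketch (assumes a CUDA-enabled build and int32 input):
+#
+#   x = torch.randint(0, 16, (1024,), device="cuda", dtype=torch.int32)
+#   counts = histogram(x, num_bins=16)               # occurrences per bin
+#   starts = cumsum(counts, dim=0, exclusive=True)   # starting offset of each bin
+#   sorted_vals, sorted_idx = argsort(x, end_bit=4)  # 4 bits cover values 0..15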
+
+
+# Export public API
+__all__ = [
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/__init__.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/activation_fn.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/all_to_all.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
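+
+
+# Split-size example: on a 2-rank process group, a rank that sends 3 rows to rank 0
+# and 5 rows to rank 1 passes input_split_sizes=[3, 5]; output_split_sizes[i] must
+# equal the number of rows rank i sends to this rank (e.g. [3, 4]).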
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/arguments.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
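+
+
+# Illustrative configuration sketch (uses only fields defined above): a dMoE set up
+# for the grouped (non-sparse) MLP backend, which __post_init__ accepts with
+# triton >= 3.2.0 as long as grouped GEMM is available:
+#
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_impl='grouped',
+#       bf16=True,
+#       fp16=False,
+#   )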
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/common.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/dmlp_registry.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
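+
+
+# Example: get(Arguments(mlp_type='glu', mlp_impl='grouped')) returns a GroupedGLU;
+# an unknown mlp_type or an unsupported backend raises ValueError.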
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/dmoe.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
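+
+    # Worked example: with blocking=128 and tokens_per_expert = [5, 200, 130],
+    # round_up gives [128, 256, 256], so padded_bins = [128, 384, 640] (inclusive
+    # cumsum) and bins = [5, 205, 335].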
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/gelu.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
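+
+# Note: 0.79788456 ≈ sqrt(2/pi) and 0.1070322243 ≈ 3 * 0.044715 * sqrt(2/pi), i.e. the
+# constants of the derivative of the tanh-approximate GeLU used in the backward above.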
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/glu.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/memory_test.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/mlp.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+        # this rank. If the master weights are created first, the PyTorch
+        # caching allocator appears to use the same memory block for these
+        # and the slice, which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+        # this rank. If the master weights are created first, the PyTorch
+        # caching allocator appears to use the same memory block for these
+        # and the slice, which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
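+# Illustrative call site (see GroupedMLP.forward below): the function is
+# invoked through the `.apply` alias with per-expert token counts, e.g.
+#
+#   out = memory_optimized_grouped_mlp(x, w1, w2, batch_sizes, activation_fn)
+#
+# where `batch_sizes` is a 1-D int64 CPU tensor holding the number of tokens
+# routed to each expert.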
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+    Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+            # Enable using a weighted sum for the shared expert output,
+            # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
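+    # Worked example for the weighted sum above (illustrative): with
+    # moe_top_k = 4 we have t_experts = 5, so the shared expert output is
+    # scaled by 1/5 and the routed expert output by 4/5.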
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/moe.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
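+# Illustrative scaling for the loss above: with moe_num_experts = 8,
+# moe_loss_weight = 0.01, num_layers = 12, tokens = 4096 and moe_top_k = 2,
+# scale = (8 * 0.01) / (12 * 4096 * 2) ~ 8.1e-7 before the dot product.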
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
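+    # Example (illustrative): with top_k = 2, tokens = 4096, world_size = 1,
+    # num_experts = 64 and moe_capacity_factor = 1.25, the capacity is
+    # int(1.25 * 2 * 4096 / 64) = 160 tokens per expert.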
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
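+    # Example (illustrative): for top_expert = [2, 0, 2, 1] with 3 experts,
+    # sort gives bin_ids = [0, 1, 2, 2] and indices = [1, 3, 0, 2], the
+    # histogram gives tokens_per_expert = [1, 1, 2], and the inclusive
+    # cumsum gives bins = [1, 2, 4].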
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/mpu.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
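+# Example (illustrative): with an expert parallel world size of 8 and
+# moe_num_experts = 4, expert_sharding_degree = 4 and hidden_sharding_degree = 2,
+# so each rank owns one expert and ffn_hidden_size // 2 of its features.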
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/router.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch210-cxx11-cpu-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_megablocks_cpu_7a6bcf4.abi3.so b/build/torch210-cxx11-cpu-x86_64-linux/_megablocks_cpu_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..5ba8ac355e5d280728ad5f5585983fc4627eb4ad
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_megablocks_cpu_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10f9bb557d5036d3b215b4dabcb0c44a51d276a6a3ab67c37c37dfca3259f824
+size 2219080
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_ops.py b/build/torch210-cxx11-cpu-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..102573e975ffc7897e0e9c4edca028ed1dc67419
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cpu_7a6bcf4
+ops = torch.ops._megablocks_cpu_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cpu_7a6bcf4::{op_name}"
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/_version.py b/build/torch210-cxx11-cpu-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/backend/__init__.py b/build/torch210-cxx11-cpu-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/backend/kernels.py b/build/torch210-cxx11-cpu-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
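+# Example of the bins/padded_bins contract (illustrative): with
+# tokens_per_expert = [1, 3, 2], bins = [1, 4, 6]; if each expert's rows are
+# padded up to a multiple of 4, padded_bins = [4, 8, 12] and padded_gather
+# allocates padded_bins[-1] = 12 output rows.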
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per entry in 'tokens * top_k'. Array 'x' has a greater
+    # or equal number of rows since it could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# x: (num_experts, expert_capacity, num_columns), real.
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens * top_k), real.
+# indices: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/benchmark_util.py b/build/torch210-cxx11-cpu-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/cpu_fused_moe.py b/build/torch210-cxx11-cpu-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
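+# Example (illustrative): for scalar inputs gate = 2.0 and up = 0.5 with the
+# default alpha and limit, glu = 2.0 * sigmoid(3.404) ~ 1.94, so the output
+# is (0.5 + 1) * glu ~ 2.90.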
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation iterates over the experts and processes each expert's
+    assigned tokens with batched PyTorch operations, avoiding per-token Python loops.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # For each expert, find the (token_idx, topk_pos) pairs routed to it and
+    # process those tokens as a single batch.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/cpu_moe_cpp.py b/build/torch210-cxx11-cpu-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
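+
+    Example (illustrative sketch; variable names are hypothetical):
+        Pack the expert weights once with ops.convert_weight_packed, then call
+        fused_moe_cpp(x_bf16, w1_packed, w2_packed, topk_weights,
+                      topk_ids.to(torch.int32), alpha=1.702, limit=7.0,
+                      is_vnni=True).
+        See CPUMegaBlocksMoeMLP.forward for the full wiring.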
+ """
+    # MXFP4/FP8 kernels only support bf16, and the C++ kernel has no fp32 path,
+    # so convert the activations when needed.
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
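+            # Pack the expert weights into the VNNI layout expected by the brgemm
+            # kernel; this runs once and later calls reuse the packed tensors.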
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm/__init__.py b/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm/backend.py b/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm/ops.py b/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
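+
+
+# Illustrative call (shapes only, assumed example): with a of shape [6, 64],
+# b of shape [2, 64, 32] and batch_sizes = torch.tensor([2, 4]), gmm runs two
+# GEMMs (2x64 @ 64x32 and 4x64 @ 64x32) and returns a [6, 32] tensor.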
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm_util.py b/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored into this package, so there is nothing to import
+    # here; the try/except mirrors the upstream megablocks structure.
+    _grouped_gemm_is_available = True
+except ImportError:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/layers.py b/build/torch210-cxx11-cpu-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
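+
+
+# Worked example (illustrative): with world_size=8, moe_num_experts=4 and
+# ffn_hidden_size=3072, the expert sharding degree is 4 and the hidden sharding
+# degree is 2, so each rank holds one expert and 1536 of its ffn features.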
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
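+    # Clamped SwiGLU (GptOss-style): next = ((up + 1) * gate * sigmoid(alpha * gate)) @ w2 + b2,
+    # with gate clamped from above at `limit` and up clamped to [-limit, limit].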
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
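+
+
+# Illustrative weighting: with moe_top_k=3 and shared_expert_weighted_sum=True,
+# the shared expert contributes 1/4 of the output and the routed experts 3/4.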
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+    assert all(
+        x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens
+        for x in expert_scores
+    )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (count_num_experts,) = tokens_per_expert.size()
+    assert count_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
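+
+
+# Illustrative example (4 tokens, 3 experts): top_expert = [2, 0, 2, 1] gives
+# bin_ids = [0, 1, 2, 2], indices = [1, 3, 0, 2], tokens_per_expert = [1, 1, 2]
+# and bins = [1, 2, 4] (the inclusive cumsum of tokens_per_expert).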
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
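+
+
+# Illustrative capacity: 1024 tokens, top_k=4, 128 experts, capacity_factor=1.0
+# and no expert parallelism give int(1.0 * 4 * 1024 / 128) = 32 slots per expert.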
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
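+    # (a dense [num_experts, num_tokens] map with the top-k weights scattered in)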
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
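+        # Number of bits the radix sort needs to cover all expert ids.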
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/megablocks/__init__.py b/build/torch210-cxx11-cpu-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/metadata.json b/build/torch210-cxx11-cpu-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..eb22148b3f551be150f7824a5684c19bbc40ae0e
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/metadata.json
@@ -0,0 +1,8 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cpu"
+ }
+}
\ No newline at end of file
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/__init__.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per fp16 element.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
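+    # NOTE: init_process_group returns None; passing None to the calls below
+    # selects the default process group.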
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/binned_gather.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
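+        # The gradient of a binned gather is a binned scatter with the same
+        # indices and bins.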
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/binned_scatter.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
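+        # x is laid out as (num_experts, expert_capacity, hidden_size); the
+        # capacity is needed to rebuild this layout in the backward pass.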
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/cumsum.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
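+        # 1-D inputs are promoted to 2-D so the kernel can scan along dim 1,
+        # then squeezed back to their original shape.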
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/gather.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/histogram.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
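+        # Counts how many entries of x fall into each integer bucket in
+        # [0, max_val), e.g. tokens per expert when x holds expert ids.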
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/histogram_benchmark.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
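+    # Time each call with CUDA events; elapsed_time reports milliseconds.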
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/matmul_benchmark.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1), which in turn calls
+# torch.as_strided(...). Build the strided view directly to avoid the
+# overhead of that call chain.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/padded_gather.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/padded_scatter.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
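+        # Save x only when a gradient w.r.t. the routing weights is needed,
+        # to avoid holding on to the activations unnecessarily.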
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/permute_benchmark.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/repeat.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/replicate.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/round_up.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
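+    # e.g. round_up(torch.tensor([3, 130], dtype=torch.int32), 128) -> tensor([128, 256]).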
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/scatter.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/sort.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
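+        # end_bit bounds how many low-order bits the underlying radix sort
+        # inspects; None defaults to the full bit width of the dtype.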
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
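+
+# Minimal usage sketch (assumes the C++ ops are built), mirroring the benchmark scripts:
+#   bin_ids, indices = sort(top_expert)  # sorted values and their argsort indices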
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/sort_benchmark.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/stk_autocast.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
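+        # When autocast is active, disable it inside the op and cast eligible
+        # floating-point CUDA args to the autocast dtype up front.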
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/sum.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/ops/topology.py b/build/torch210-cxx11-cpu-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
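+        # One int16 entry per nonzero block; ops.indices fills in the column
+        # index of each block in the block-sparse topology.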
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/__init__.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/__init__.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/autocast.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/sputnik.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
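+    # A non-contiguous 2-D tensor stored column-major (strides (1, num_rows))
+    # is treated as a transposed view.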
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/triton_kernels.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: tensor has a dim of length {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to the sparse output matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
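+    # DDS ("dense = dense @ sparse"): each program computes one dense output
+    # tile. offsets[pid_n]..offsets[pid_n+1] select the nonzero blocks of the
+    # sparse rhs falling in this output block column; column_indices picks the
+    # matching K-block of the dense lhs. With trans_B, blocks are read in their
+    # stored order instead of through block_offsets_t.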
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
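+    # Dense = sparse @ dense, written into `out` in place (no return value).
+    # `data` plus the index/offset tensors describe the sparse lhs in BCSR
+    # form; transpose_a switches to the transposed metadata
+    # (offsets_t / column_indices_t).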
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
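+    # Sparse = dense @ dense: only the nonzero blocks named by
+    # row_indices/column_indices are computed (one program per block) and
+    # written into `out`, which holds the BCSR block data.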
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
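+    # Expand BCSR offsets into one row index per nonzero block: program `pid`
+    # (one per block row) writes its row id into the slot of every block in
+    # that row.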
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/matrix.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
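+    # Build the transpose metadata without moving block data: sort blocks by
+    # column to obtain the transposed column indices, keep the permutation as
+    # block_offsets_t, and build offsets_t from a per-column nonzero histogram.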
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/__init__.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
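+
+# Illustrative usage (assumes `a` is an existing stk.Matrix and `b` shares its
+# topology, i.e. was built from the same row_indices/column_indices/offsets):
+#
+#   b = Matrix(a.size(), a.data * 2, a.row_indices, a.column_indices, a.offsets)
+#   c = mul(a, b)  # c.data == a.data * b.data, same sparsity pattern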
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/linear_ops.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
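+    # Treat x and y as close if at most `pct` percent of entries fall outside
+    # a relative tolerance of 5e-2 (loose enough for fp16/bf16 results).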
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
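+    # Expand block coordinates to element coordinates, e.g. with blocking=2 the
+    # block (r, c) becomes (2r, 2c), (2r, 2c+1), (2r+1, 2c), (2r+1, 2c+1).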
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/random/__init__.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/random/random_ops.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
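+    # Build a dense 0/1 mask with round(block_rows * block_cols * (1 - sparsity))
+    # randomly chosen nonzero blocks, each expanded to a blocking x blocking
+    # patch of ones.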
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/stk/random/random_ops_test.py b/build/torch210-cxx11-cpu-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py b/build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM path: builds the per-expert token offsets on the host.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
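+    # Heuristic: pick the smallest candidate block size for which the number of
+    # (expert, token-block) pairs still fits within one block; otherwise fall
+    # back to 1024.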
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ op. Temporarily exposed here before the GEMM fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
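+        # Reserve a 256-byte-aligned slice of the flat uint8 workspace and
+        # record its (size, offset) under `name`.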
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
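+
+    Example (illustrative):
+        x = torch.tensor([1, 2, 3], device="cuda")
+        out = torch.empty_like(x)
+        exclusive_cumsum(x, 0, out)  # out is expected to hold [0, 1, 3]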
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
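+
+    Example (illustrative):
+        cumsum(torch.tensor([1, 2, 3], device="cuda"), dim=0)                  # [1, 3, 6]
+        cumsum(torch.tensor([1, 2, 3], device="cuda"), dim=0, exclusive=True)  # [0, 1, 3]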
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
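+
+    Example (illustrative):
+        values, order = argsort(torch.tensor([3, 1, 2], device="cuda", dtype=torch.int32))
+        # values is expected to be [1, 2, 3] and order the original positions [1, 2, 0]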
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/activation_fn.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/all_to_all.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
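+    # Illustrative semantics: with two ranks, input_split_sizes=[a, b] sends the first a
+    # rows of x to rank 0 and the next b rows to rank 1, while output_split_sizes describes
+    # the rows received on this rank. The returned handle is None unless async_op=True.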
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/arguments.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+            try:
+                import triton
+                from packaging import version
+                if version.parse(triton.__version__) >= version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/common.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e. only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
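+
+    Example (illustrative): mlp_type='glu' with mlp_impl='grouped' resolves to
+    GroupedGLU, while mlp_type='mlp' with mlp_impl='sparse' resolves to SparseMLP.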
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/dmoe.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
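+    # e.g. a 0-dim tensor(5) becomes tensor([5]); tensors with at least one
+    # dimension are returned unchanged.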
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
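+        # Illustrative: padded_tokens=256, ffn_hidden_size=512 and blocking=128 give
+        # block_rows=2, blocks_per_row=4 and offsets [0, 4, 8].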
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/gelu.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
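+    # Derivative of the tanh-approximated GeLU: 0.79788456 ~= sqrt(2/pi) and
+    # 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi) (coefficient of the cubic term).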
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/glu.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/memory_test.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/mlp.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
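+    # Identity in the forward pass; the backward pass multiplies the incoming gradient
+    # by `scale` (used to rescale expert gradients under expert model parallelism).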
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+    # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/moe.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
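+
+# Usage sketch (illustrative, not a prescribed API): during training each MoE layer
+# appends (tokens_per_expert, scores) via save_load_balancing_loss(); a training loop
+# can then add batched_load_balancing_loss(args) to its task loss and call
+# clear_load_balancing_loss() after each step so the buffer does not grow unbounded.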
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
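+
+ # Illustrative numbers: with tokens = 4096, world_size = 1, top_k = 2,
+ # num_experts = 64 and moe_capacity_factor = 1.25, tokens_per_expert is
+ # 2 * 4096 / 64 = 128 and the returned capacity is int(1.25 * 128) = 160.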
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+ # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to set up for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/mpu.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
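+
+# Worked example (illustrative): with an expert parallel world size of 8,
+# moe_num_experts = 4 and ffn_hidden_size = 4096, expert_sharding_degree() = 4,
+# hidden_sharding_degree() = 2, experts_per_rank() = 1 and features_per_rank() = 2048.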
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/router.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+ # so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
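+
+ # Illustrative: with moe_jitter_eps = 0.01 each activation is scaled by a
+ # uniform random factor drawn from [0.99, 1.01) during training.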
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_layers/sharedexpert_registry.py b/build/torch210-cxx11-cu126-aarch64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch210-cxx11-cu126-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..aaae1f716dedb9f0e6616ac3fe9ad730ed68f7f6
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39609d4551be5cbaf91e53da23cf4826040984b94aa8c5a574e69de104b484bd
+size 15124328
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_version.py b/build/torch210-cxx11-cu126-aarch64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/backend/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/backend/kernels.py b/build/torch210-cxx11-cu126-aarch64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an env that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
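+
+# Shape sketch (illustrative): gather() maps x of shape [tokens, hidden] to
+# [tokens * top_k, hidden] rows grouped by expert; scatter() reverses the
+# permutation and, for top_k > 1, sums the weighted copies back to [tokens, hidden].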
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
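+
+# Shape sketch (illustrative): binned_gather() produces a dense
+# [num_experts, expert_capacity, hidden] tensor (slots beyond each expert's token
+# count stay zero); binned_scatter() folds it back to [tokens, hidden], weighting
+# and summing the top_k copies of each token.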
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/benchmark_util.py b/build/torch210-cxx11-cu126-aarch64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
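+
+# Usage sketch (illustrative, requires a CUDA device):
+#
+#   a = torch.randn(1024, 1024, device='cuda')
+#   mean_ms, std_ms = benchmark_function(lambda: a @ a)
+#   log_benchmark('MatMul', {'shape': '1024x1024'}, mean_ms, std_ms)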
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/cpu_fused_moe.py b/build/torch210-cxx11-cu126-aarch64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
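+
+# Illustrative scalar check: with gate = 2.0, up = 0.5 and the default alpha/limit,
+# the clamps are no-ops, glu = 2.0 * sigmoid(2.0 * 1.702) is approximately 1.94 and
+# the output (0.5 + 1) * glu is approximately 2.90.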
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
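+
+# Shape sketch (illustrative): for x of shape [2, 3, 16], 8 experts and
+# moe_top_k = 2, this returns logits of shape [6, 8], expert_weights [6, 2]
+# and expert_indices [6, 2] (the 2 * 3 = 6 tokens are flattened first).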
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation processes all experts in parallel using batched operations
+ instead of sequential for loops, which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Build expert mask: which tokens go to which expert
+ # expert_mask[expert_id] contains indices of (token_idx, topk_pos) pairs
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
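+
+# Minimal usage sketch (illustrative shapes and random weights):
+#
+#   x = torch.randn(4, 32)                  # 4 tokens, hidden_size = 32
+#   w1 = torch.randn(8, 32, 2 * 64) * 0.02  # 8 experts, inter_size = 64
+#   w2 = torch.randn(8, 64, 32) * 0.02
+#   weights = torch.rand(4, 2)              # top_k = 2 routing weights
+#   ids = torch.randint(0, 8, (4, 2))
+#   out = cpu_fused_moe(x, w1, w2, weights, ids)
+#   assert out.shape == (4, 32)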
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/cpu_moe_cpp.py b/build/torch210-cxx11-cu126-aarch64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm/backend.py b/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
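+
+# Illustrative usage (a sketch; following the upstream grouped_gemm convention
+# that batch_sizes is a 1-D int64 tensor on the CPU): 'a' stacks the rows of all
+# groups, 'b' holds one weight matrix per group, and group i multiplies rows
+# sum(batch_sizes[:i]) .. sum(batch_sizes[:i+1]) of 'a' by b[i]:
+#
+#   a = torch.randn(10, 64, device="cuda", dtype=torch.bfloat16)
+#   b = torch.randn(2, 64, 32, device="cuda", dtype=torch.bfloat16)
+#   sizes = torch.tensor([4, 6])   # 4 rows for group 0, 6 rows for group 1
+#   c = gmm(a, b, sizes)           # c has shape [10, 32]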
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm/ops.py b/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
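+# Autograd wrapper around the raw grouped GEMM: the forward pass runs one GEMM
+# per group, and the backward pass reuses grouped GEMMs with one operand
+# transposed to form the gradients w.r.t. 'a' and 'b'; 'batch_sizes' and
+# 'trans_b' receive no gradient.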
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm_util.py b/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# grouped_gemm is vendored in this repository (see the imports at the bottom of
+# this file), so it is always available and no import probing is needed.
+_grouped_gemm_is_available: bool = True
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+    )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+# Re-export the vendored grouped_gemm submodules under their conventional names.
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/layers.py b/build/torch210-cxx11-cu126-aarch64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
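+# Net effect of the patching above: when torch.compile traces a module that
+# calls these ops, the wrappers skip the CUDA kernels and return empty tensors
+# with plausible shape/dtype/device so shape propagation succeeds; outside of
+# compilation the original kernels run unchanged.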
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Compute how many ways the experts are sharded across the expert-parallel world size
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
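+# Worked example (hypothetical sizes): with world_size=8, moe_num_experts=4 and
+# ffn_hidden_size=3072, expert_sharding_degree=4, hidden_sharding_degree=2,
+# experts_per_rank=1 and features_per_rank=1536, i.e. each expert is split
+# across two ranks along the FFN dimension.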
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+    router_bias: Optional[torch.Tensor],
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
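+# Shape sketch for route_tokens (illustrative): for x of shape [sl, bs, hs] and
+# E experts, logits is [sl * bs, E], while expert_weights and expert_indices are
+# both [sl * bs, moe_top_k]; the top-k logits are softmaxed (and optionally
+# p-normalized) to form the combine weights.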
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
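+
+# Note on the activation above: w1 stores gate and up columns interleaved along
+# its last dimension (even columns gate, odd columns up), and the clamped
+# swiglu-style nonlinearity is glu = gate * sigmoid(alpha * gate) with output
+# ((up + 1) * glu) @ w2 + w2_bias.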
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
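+# Worked example: with moe_top_k = 4 and shared_expert_weighted_sum = True the
+# shared expert contributes 1/5 of the output and the routed experts 4/5;
+# otherwise the two outputs are simply added.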
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counted_num_experts,) = tokens_per_expert.size()
+    assert counted_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
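+# Sanity check (illustrative): if routing were perfectly uniform, each expert
+# would receive tokens * top_k / num_experts tokens with a mean score of
+# 1 / num_experts, so the dot product is tokens * top_k / num_experts and the
+# loss evaluates to exactly 1.0.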
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
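+# Meaning of the outputs above: 'indices' is the permutation that groups token
+# slots by expert, 'bin_ids' holds the expert id of each sorted slot, 'bins' is
+# the inclusive cumulative sum of tokens per expert (the end offset of each
+# expert's bin), and 'tokens_per_expert' is the histogram itself.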
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
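+# Worked example (hypothetical sizes): 1024 tokens, top_k=4, 128 experts, a
+# single rank and capacity_factor=1.0 give 4 * 1024 / 128 = 32 slots per expert;
+# a computed capacity of 0 is later replaced by the actual maximum number of
+# tokens routed to any single expert.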
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Dispatch asynchronously so the local permutation below can overlap
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
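+
+    # Illustrative usage (hypothetical tensors): after the router and experts
+    # are attached, shared-expert weights can be provided via
+    #   mlp.set_shared_expert_weights(up_w, down_w, weighted_sum=True)
+    # where up_w has shape [shared_ffn, hidden] and down_w has shape
+    # [hidden, shared_ffn], the layout expected by F.linear in
+    # shared_mlp_forward.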
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Platform overrides: on XPU, replace MegaBlocksMoeMLP with the fused XPU
+# implementation; the CPU C++ implementation is always exposed.
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/megablocks/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/metadata.json b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..155112c59509d3b4d07f4d090cbf57071e3f5217
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/all_to_all_benchmark.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/binned_gather.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/binned_scatter.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/cumsum.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/gather.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/histogram.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
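+
+# Example (comment only): for a contiguous (m, n) tensor, transpose_view returns
+# an (n, m) view over the same storage, equivalent to x.t() but without the
+# transpose/as_strided call chain described above.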
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/padded_gather.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
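+
+# Note (descriptive comment): unlike gather, padded_gather also takes padded_bins,
+# so each expert's tokens land at block-aligned offsets (padded_bins is built with
+# round_up + inclusive_cumsum elsewhere in this package).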
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/padded_scatter.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
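+
+# Note (descriptive comment): x is only saved for backward when a gradient for
+# `weights` is needed (needs_input_grad[3]), keeping memory low in the common
+# unweighted case; dgrad reuses padded_gather and wgrad uses padded_scatter_wgrad.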
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/repeat.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
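+
+# Example (comment only): repeat(x, torch.Size((1, 1))) returns x unchanged,
+# avoiding the copy that x.repeat(1, 1) would allocate.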
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/replicate.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
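+
+# Note (descriptive comment): replicate appears to expand one value per bin
+# (e.g. a per-expert weight) into num_outputs per-position values, with the
+# backward pass reducing gradients back to one value per bin.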
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/round_up.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
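+
+# Example (comment only): round_up(tensor([3, 130], dtype=torch.int32), 128)
+# -> tensor([128, 256]), e.g. padding per-expert token counts up to a block size.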
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/scatter.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
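+
+# Note (descriptive comment): scatter is the inverse of the gather op above;
+# when weights are provided, the top_k copies of each token are combined with
+# those weights, and backward uses scatter_wgrad for the weight gradients.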
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/sort.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
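+
+# Usage sketch (comment only): for expert ids in [0, num_experts), passing
+# end_bit=ceil(log2(num_experts)) limits how many key bits the sort considers;
+# the op returns (sorted_values, permutation_indices).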
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/stk_autocast.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/sum.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
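+
+# Note (descriptive comment): squeezing a singleton dimension yields the same
+# values as summing over it while skipping the reduction entirely.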
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/ops/topology.py b/build/torch210-cxx11-cu126-aarch64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
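+
+# Note (descriptive comment): per its use in matmul_benchmark.py, the returned
+# int16 tensor holds the column index of every nonzero block in the block-sparse
+# expert topology, one entry per output block.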
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/autocast.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/sputnik.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
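+# Naming note (descriptive comment): following stk's (output, lhs, rhs)
+# convention, SDD produces a sparse output from two dense operands, DSD a
+# dense output from sparse @ dense, and DDS a dense output from dense @ sparse.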
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/triton_kernels.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
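+
+# Example (comment only): _row_indices_kernel expands BCSR row offsets into one
+# row id per nonzero block, e.g. offsets [0, 2, 3] -> row_indices [0, 0, 1].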
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/matrix.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
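+
+# Illustrative example (hypothetical 2x2 block grid): for nonzero blocks at
+# block coordinates (0, 1) and (1, 0), sorting by column gives
+# column_indices_t = [1, 0], offsets_t = [0, 1, 2] and block_offsets_t = [1, 0],
+# i.e. the transpose visits the same data blocks in column-major order.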
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
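+
+# Illustrative usage (hypothetical; assumes both operands were built from the
+# same mask so they share a topology):
+#
+#   mask = stk.random.dense_mask(256, 256, sparsity=0.5, blocking=128)
+#   a = stk.ops.to_sparse((torch.randn(256, 256) * mask).half().cuda(), 128)
+#   b = stk.ops.to_sparse((torch.randn(256, 256) * mask).half().cuda(), 128)
+#   c = stk.ops.mul(a, b)  # reuses a's metadata; c.data == a.data * b.data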
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/linear_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
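+
+# Calling-pattern sketch (illustrative, not executed here): `sdd` multiplies two
+# dense operands but only materializes the output blocks that are nonzero in
+# `topo`, so the result reuses topo's sparsity metadata:
+#
+#   topo = stk.random.mask(m, n, sparsity=0.75, blocking=128)  # output pattern
+#   out = sdd(a, b, topo)                  # a: [m, k] dense, b: [k, n] dense
+#   # to_dense(out) matches (a @ b) masked by topo's nonzero blocks.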
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
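+
+# Note on the tolerance policy above: with the default pct=0.25, a comparison
+# passes as long as at least 99.75% of entries are within rtol=5e-2 of the
+# reference, which tolerates occasional half-precision outliers.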
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
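+
+# Round-trip sketch (illustrative; mirrors the commented-out format-conversion
+# test in matrix_ops_test.py): converting a block-sparse dense tensor to sparse
+# and back should reproduce it exactly:
+#
+#   mask = stk.random.dense_mask(128, 256, sparsity=0.5, blocking=16)
+#   x = (torch.randn(128, 256) * mask).half()
+#   assert torch.equal(to_dense(to_sparse(x, blocking=16)), x)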
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/random/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/random/random_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
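+
+# Illustrative expectation (matches the commented-out RandomOpsTest cases): for
+# rows=128, cols=256, sparsity=0.5, blocking=16 there are (128//16)*(256//16)
+# = 128 blocks, round(128 * 0.5) = 64 of them are kept, giving 64 * 16 * 16
+# nonzero entries, each equal to 1.0.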
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py b/build/torch210-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from .. import random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/xpu_fused_moe.py b/build/torch210-cxx11-cu126-aarch64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM wrapper; cutlass_grouped_gemm_xe2 below covers the Xe2 path.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+        for x in arr:
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
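+
+# Example of the offset construction above (illustrative): for
+# expert_token_count = [3, 0, 2] the exclusive prefix sum yields
+# expert_offset = [0, 3, 3, 5], i.e. expert i owns rows
+# [expert_offset[i], expert_offset[i + 1]) of the grouped GEMM operands.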
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
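+
+# Worked example of the heuristic above: with num_tokens=1000 and
+# num_experts_per_node=8, candidates 32 and 64 fail the check
+# (ceilDiv(1000, 32) * 8 = 256 > 32; ceilDiv(1000, 64) * 8 = 128 > 64),
+# while 128 passes (ceilDiv(1000, 128) * 8 = 64 <= 128), so 128 is returned.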
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
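+
+# Illustrative behaviour (hypothetical values): a 16-byte uint8 slice of the
+# workspace reinterpreted with dtype=torch.int64 becomes a 2-element int64
+# tensor carrying the same bytes; callers are expected to pass slices whose
+# length is an exact multiple of the element size.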
+
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
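+
+# Illustrative mapping for the u4 -> s4 shift above (the repacking that follows
+# is kernel-specific): unsigned nibble 0x0 maps to -8, 0x8 maps to 0 and 0xF
+# maps to +7, which removes the implicit zero point of 8 from the grouped GEMM.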
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+    num_experts: int
+    ep_rank: int, expert-parallel rank of this device (default 0)
+    ep_size: int, expert-parallel world size (default 1)
+    is_fp8: bool
+    is_int4: bool
+    is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ function. Temporarily exposed here until the GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
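+
+# Example: with moe_jitter_eps=0.01 every activation is scaled by an independent
+# uniform factor drawn from [0.99, 1.01); route_tokens_xpu only applies this
+# during training.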
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
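+
+# Shape sketch (illustrative): for x flattened to [tokens, hidden] and
+# moe_top_k=k, logits is [tokens, num_experts] while expert_weights and
+# expert_indices are both [tokens, k]; the softmax is taken over the k selected
+# logits only.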
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
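+
+# Illustrative semantics (assuming the underlying ops follow the usual cumsum
+# convention): for x = [1, 2, 3],
+#   cumsum(x)                  -> [1, 3, 6]   (inclusive)
+#   cumsum(x, exclusive=True)  -> [0, 1, 3]   (exclusive)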
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+    # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/activation_fn.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
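+
+
+# Usage sketch (illustrative; `topo` stands for any block-sparse stk.Matrix,
+# e.g. the one built by the dMoE topology step):
+#
+#     y = act_fn(topo, torch.nn.functional.gelu, approximate='tanh')
+#     y, grad_fn = act_fn(topo, torch.nn.functional.gelu, return_grad_fn=True)
+#
+# Only the nonzero blocks (`x.data`) are transformed; the sparsity metadata is
+# carried over unchanged to the returned Matrix.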
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/all_to_all.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
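+
+
+# Usage sketch (illustrative; assumes torch.distributed is initialized and that
+# `send_counts`/`recv_counts` were exchanged beforehand, e.g. via an all_to_all
+# on the per-rank token counts):
+#
+#     out, handle = all_to_all(
+#         x,
+#         output_split_sizes=recv_counts,
+#         input_split_sizes=send_counts,
+#         group=group,
+#         async_op=True,
+#     )
+#     handle.wait()  # block before reading `out`
+#
+# The backward pass reruns all_to_all with the split sizes swapped, routing
+# gradients back to the ranks that produced each row.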
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/arguments.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+                import triton
+                from packaging import version
+                # Compare versions numerically; a string comparison misorders '3.10.0'.
+                if version.parse(triton.__version__) >= version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
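+# Illustrative configuration (example values only; requires a CUDA device since
+# `device` defaults to torch.cuda.current_device(), and mlp_impl='grouped'
+# additionally requires grouped GEMM support):
+#
+#     args = Arguments(
+#         hidden_size=1024,
+#         ffn_hidden_size=4096,
+#         moe_num_experts=8,
+#         moe_top_k=2,
+#         mlp_impl='grouped',   # 'sparse' requires triton < 3.2.0
+#         fp16=False,
+#         bf16=True,
+#     )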
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/common.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
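+
+
+# Sketch of cast_if_autocast_enabled (illustrative; `w` is a hypothetical CUDA
+# tensor): inside an autocast region it casts to the active autocast dtype,
+# otherwise it returns the tensor unchanged.
+#
+#     with torch.autocast('cuda', dtype=torch.bfloat16):
+#         w16 = cast_if_autocast_enabled(w)   # -> bfloat16 copy of w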
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
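+
+
+# For example (illustrative; assumes a CUDA device and, for 'grouped', that
+# grouped GEMM support is installed), the following resolves to a GroupedGLU:
+#
+#     expert_mlp = get(Arguments(mlp_type='glu', mlp_impl='grouped'))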
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/dmoe.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
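+        # Example (illustrative): 256 padded tokens with blocking=128 and
+        # ffn_hidden_size=512 give block_rows=2 and blocks_per_row=4, so
+        # offsets = [0, 4, 8].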
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
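+        # Worked example (illustrative): with blocking=128 and two experts
+        # receiving 3 and 130 tokens, tokens_per_expert = [3, 130],
+        # padded_tokens_per_expert = [128, 256], padded_bins = [128, 384] and
+        # bins = [3, 133]; expert i's padded rows start at padded_bins[i - 1]
+        # (taken as 0 for i = 0).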
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/gelu.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
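+# The constants above come from the tanh approximation
+#   gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
+# with sqrt(2/pi) ~ 0.79788456 and 3 * 0.044715 * sqrt(2/pi) ~ 0.1070322243,
+# so `ff` is d(gelu)/dx evaluated with the cached tanh value.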
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/glu.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
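+        # In dense notation the return below computes
+        #     (activation_fn(x @ w1.T) * (x @ v1.T)) @ w2,
+        # except that both sdd products only materialize the blocks selected by
+        # `topo`, so the cost scales with the tokens routed to this rank.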
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/memory_test.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/mlp.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
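+# Sharding example for create_moe_expert_weights (illustrative numbers): with
+# 8 experts, expert_sharding_degree=4 and hidden_sharding_degree=2, rank 5 has
+# expert_rank = 5 % 4 = 1 and row_rank = 5 // 4 = 1, so it keeps experts
+# [2, 4) and rows [ffn_hidden_size // 2, ffn_hidden_size) of each of them.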
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by number of experts used
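+            # Example (illustrative): with moe_top_k = 3, t_experts = 4 and the
+            # result is shared_expert_out / 4 + expert_out * 3 / 4.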
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/moe.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
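+        # Example (illustrative): top_k=2, 4096 tokens on this rank, a world
+        # size of 1 and 64 experts give tokens_per_expert = 128; with
+        # moe_capacity_factor=1 each expert then keeps at most 128 tokens.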
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
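+        # Under perfectly uniform routing tokens_per_expert equals
+        # tokens * top_k / num_experts for every expert and
+        # expert_scores.mean(0) equals 1 / num_experts, so the scaled dot
+        # product below evaluates to exactly 1.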
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to set up for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
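+
+
+# Minimal usage sketch (illustrative only; the Arguments values below are
+# assumptions, not defaults from this module):
+#
+#   args = Arguments(hidden_size=1024, ffn_hidden_size=4096,
+#                    moe_num_experts=8, moe_top_k=2)
+#   moe = MoE(args)
+#   out = moe(torch.randn(seq_len, batch_size, args.hidden_size))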
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/mpu.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
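+
+
+# Worked example (illustrative): with an expert-parallel world size of 8 and
+# moe_num_experts=4, expert_sharding_degree=4, hidden_sharding_degree=2,
+# experts_per_rank=1 and features_per_rank=ffn_hidden_size // 2 (assuming
+# ffn_hidden_size is even).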
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/router.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
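+
+
+# For reference, each entry of the returned tensor is
+# moe_zloss_weight * mean_i(logsumexp_j(logits[i, j]) ** 2), computed over the
+# logits captured by _save_router_logits during the forward pass.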
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
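+ # e.g. with moe_jitter_eps = 0.1 the multiplicative factor below is uniform
+ # over [0.9, 1.1).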
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
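+
+
+# Shape sketch (illustrative): for input x of shape [sl, bs, hidden_size],
+# scores is [sl * bs, moe_num_experts] and expert_weights / expert_indices are
+# both [sl * bs, moe_top_k].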
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch210-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..4e4162acffa15b572b47e28662e3ea8dc8259eee
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7949dd996d24b131bc32dd98e15b6bf00c5f4cad2f17cd96edec5f5ae90544de
+size 15061056
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_version.py b/build/torch210-cxx11-cu126-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/backend/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/backend/kernels.py b/build/torch210-cxx11-cu126-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have
+# CUDA. This approach preserves the original code while enabling testing
+# without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
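+
+
+# Example (illustrative): with top_k=1, three experts and per-expert token
+# counts [2, 1, 3], bins = [2, 3, 6]. If each expert's bin is padded up to a
+# multiple of 4, padded_bins = [4, 8, 12] and padded_gather emits 12 output
+# rows, with unused rows left as zeros.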
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the number of output rows equals the
+ # number of input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/benchmark_util.py b/build/torch210-cxx11-cu126-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
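+
+
+# Usage sketch (illustrative; assumes CUDA tensors a and b already exist):
+#   mean_ms, std_ms = benchmark_function(lambda: torch.mm(a, b))
+#   log_benchmark('matmul', {'shape': tuple(a.shape)}, mean_ms, std_ms)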
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/cpu_fused_moe.py b/build/torch210-cxx11-cu126-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
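+
+
+# Quick numeric check (illustrative): swigluoai_activation(torch.zeros(1),
+# torch.zeros(1)) gives (0 + 1) * (0 * sigmoid(0)) = 0, and large positive
+# gate/up values are clamped at limit=7.0 before the product.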
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation processes all experts in parallel using batched operations
+ instead of sequential for loops, which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Build expert mask: which tokens go to which expert
+ # expert_mask[expert_id] contains indices of (token_idx, topk_pos) pairs
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
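+
+
+# Usage sketch (illustrative; the shapes below are assumptions chosen for the
+# example, with 2 experts, hidden_size=16 and inter_size=4):
+#   x = torch.randn(4, 16)
+#   w1 = torch.randn(2, 16, 8)   # [num_experts, hidden, 2 * inter]
+#   w2 = torch.randn(2, 4, 16)   # [num_experts, inter, hidden]
+#   _, weights, ids = route_tokens_cpu(x, torch.randn(2, 16), None, 1, 2)
+#   y = cpu_fused_moe(x, w1, w2, weights, ids)   # -> [4, 16]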
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/cpu_moe_cpp.py b/build/torch210-cxx11-cu126-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
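+
+
+# NOTE (illustrative): usage mirrors cpu_fused_moe in cpu_fused_moe.py; the
+# extra arguments select quantization (int8 / fp8 / mxfp4) and, when alpha and
+# limit are given, the swigluoai activation instead of silu_and_mul.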
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py b/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
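+
+# Shape sketch (illustrative): for a=[tokens, k], b=[num_experts, k, n] and
+# batch_sizes summing to tokens, gmm returns [tokens, n]. With trans_a=True,
+# a and b are both 2-d and the result has shape
+# [batch_sizes.shape[0], a.shape[1], b.shape[1]].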
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py b/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm_util.py b/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
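+ # grouped_gemm is vendored alongside this module, so the optional import is
+ # left commented out and availability is effectively hard-coded to True.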
+ # import grouped_gemm
+ pass
+ _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/layers.py b/build/torch210-cxx11-cu126-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
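The patching above swaps each op for a wrapper that, while `torch.compile` is tracing, skips the CUDA kernel and returns empty tensors with the right shape, dtype, and device. A standalone sketch of the same gating pattern (illustrative only; `torch.compiler.is_compiling()` is assumed to exist in the installed PyTorch, and `torch.bincount` stands in for the real histogram kernel):

```python
import torch


def _histogram_eager(x: torch.Tensor, max_val: int) -> torch.Tensor:
    # Stand-in for the compiled CUDA histogram op.
    return torch.bincount(x, minlength=max_val).to(torch.int32)


def histogram(x: torch.Tensor, max_val: int) -> torch.Tensor:
    if torch.compiler.is_compiling():
        # During graph capture only metadata matters, so return an empty
        # tensor with the correct shape/dtype/device.
        return torch.empty((max_val,), dtype=torch.int32, device=x.device)
    return _histogram_eager(x, max_val)


x = torch.tensor([0, 1, 1, 3])
print(histogram(x, max_val=4))  # tensor([1, 2, 0, 1], dtype=torch.int32)
```
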
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree based on world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
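The helpers above split the world size between expert sharding and hidden sharding. A small worked example with assumed sizes (16 ranks, 8 experts, ffn_hidden_size 3072):

```python
# Illustrative arithmetic mirroring the helpers above (assumed sizes).
world_size, moe_num_experts, ffn_hidden_size = 16, 8, 3072

esd = min(world_size, moe_num_experts)   # expert_sharding_degree -> 8
hsd = world_size // esd                  # hidden_sharding_degree -> 2
assert moe_num_experts % esd == 0 and ffn_hidden_size % hsd == 0
assert esd * hsd == world_size

print(moe_num_experts // esd)    # experts_per_rank  -> 1
print(ffn_hidden_size // hsd)    # features_per_rank -> 1536
```
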
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
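`route_tokens` flattens the tokens, scores them with the linear router, keeps the top-k experts, and softmaxes the selected scores. A shape-level sketch with plain torch and assumed sizes (jitter, weight normalization, and uniform assignment omitted):

```python
import torch

bs, sl, hidden, num_experts, top_k = 2, 3, 16, 8, 2
x = torch.randn(bs, sl, hidden)
router_weight = torch.randn(num_experts, hidden)
router_bias = torch.zeros(num_experts)

x_flat = x.view(-1, hidden)                       # [bs * sl, hidden]
logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
expert_weights, expert_indices = torch.topk(logits, top_k, dim=-1)
expert_weights = expert_weights.softmax(dim=-1)

print(logits.shape, expert_weights.shape, expert_indices.shape)
# torch.Size([6, 8]) torch.Size([6, 2]) torch.Size([6, 2])
```
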
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
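`mlp_forward` expects batched per-expert weights with the gate and up projections interleaved along the last dimension of `w1`, and applies a clamped GLU (`gate * sigmoid(alpha * gate)`). A shape check with assumed sizes:

```python
import torch

experts, capacity, hs, isz = 4, 8, 16, 32      # assumed sizes
x = torch.randn(experts, capacity, hs)
w1 = torch.randn(experts, hs, 2 * isz) * 0.02  # gate/up interleaved
w1_bias = torch.zeros(experts, 2 * isz)
w2 = torch.randn(experts, isz, hs) * 0.02
w2_bias = torch.zeros(experts, hs)
alpha, limit = 1.702, 7.0

gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
gate, up = gate_up[..., ::2], gate_up[..., 1::2]   # de-interleave
gate = gate.clamp(max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
out = torch.bmm((up + 1) * glu, w2) + w2_bias[..., None, :]
print(out.shape)  # torch.Size([4, 8, 16])
```
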
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
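With `shared_expert_weighted_sum` enabled, the shared expert is treated as one extra expert next to the `moe_top_k` routed ones, so the blend is a fixed convex combination. Quick check with assumed values:

```python
import torch

moe_top_k = 4
shared_expert_out = torch.ones(2, 3)   # stand-in activations
expert_out = torch.zeros(2, 3)

total_experts = moe_top_k + 1
combined = (shared_expert_out * (1.0 / total_experts)
            + expert_out * (moe_top_k / total_experts))
print(combined)  # all entries are 0.2: the shared expert contributes 1/5
```
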
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
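The load-balancing loss is a scaled dot product between per-expert token counts and the mean router scores; the `num_experts / (tokens * top_k)` scale makes a perfectly balanced assignment with uniform scores evaluate to 1.0. A tiny check with assumed numbers:

```python
import torch

tokens, top_k, num_experts = 8, 1, 4
# Perfectly balanced: each expert receives tokens / num_experts tokens
# and a uniform mean score of 1 / num_experts.
tokens_per_expert = torch.full((num_experts,), tokens / num_experts)
expert_scores = torch.full((tokens, num_experts), 1.0 / num_experts)

scale = num_experts / (tokens * top_k)
loss = scale * torch.dot(tokens_per_expert, expert_scores.mean(dim=0))
print(loss)  # tensor(1.) in the balanced case
```
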
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
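`indices_and_bins` sorts the token-to-expert assignments, counts tokens per expert, and turns the counts into inclusive-cumsum bin boundaries. Pure-torch equivalents (argsort / bincount / cumsum) illustrate the intent on a small assumed input:

```python
import torch

num_experts = 4
top_expert = torch.tensor([2, 0, 2, 1, 0, 2])

# ops.sort -> (sorted expert ids, permutation indices)
indices = torch.argsort(top_expert, stable=True)
bin_ids = top_expert[indices]
# ops.histogram -> tokens per expert; ops.inclusive_cumsum -> bin boundaries
tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)
bins = torch.cumsum(tokens_per_expert, dim=0)

print(bin_ids.tolist())            # [0, 0, 1, 2, 2, 2]
print(tokens_per_expert.tolist())  # [2, 1, 3, 0]
print(bins.tolist())               # [2, 3, 6, 6]
```
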
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
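`expert_capacity_fn` sizes each expert's buffer as `capacity_factor * top_k * tokens * world_size / num_experts`. Quick arithmetic with assumed numbers:

```python
# Assumed: 4096 tokens on this rank, top_k=4, 128 experts,
# no expert parallelism (world_size=1), capacity_factor=1.0.
tokens, top_k, num_experts = 4096, 4, 128
world_size, capacity_factor = 1, 1.0

tokens_per_expert = top_k * tokens * world_size / num_experts   # 128.0
expert_capacity = int(capacity_factor * tokens_per_expert)      # 128
print(expert_capacity)
```
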
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Kick off the asynchronous token-count exchange across ranks
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
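`create_shared_expert_weights` only allocates the two projection matrices and runs the supplied initializers on them; the shared expert itself is just an up projection, activation, and down projection (see `shared_mlp_forward` above). A standalone sketch with assumed sizes and `torch.nn.init` initializers:

```python
import torch

hidden_size, shared_hidden = 64, 256    # assumed sizes

# Mirrors create_shared_expert_weights with kaiming-uniform initializers.
up_proj_weight = torch.empty(shared_hidden, hidden_size)
down_proj_weight = torch.empty(hidden_size, shared_hidden)
torch.nn.init.kaiming_uniform_(up_proj_weight)
torch.nn.init.kaiming_uniform_(down_proj_weight)

# Shared-expert MLP: up projection, activation, down projection.
x = torch.randn(10, hidden_size)
y = torch.nn.functional.linear(x, up_proj_weight)
y = torch.nn.functional.gelu(y)
y = torch.nn.functional.linear(y, down_proj_weight)
print(y.shape)  # torch.Size([10, 64])
```
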
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..155112c59509d3b4d07f4d090cbf57071e3f5217
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per fp16 element.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/binned_gather.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/binned_scatter.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/cumsum.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
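
Both wrappers dispatch to the compiled `megablocks_ops` extension and do not track gradients. The expected semantics, illustrated with plain torch on a small assumed input:

```python
import torch

x = torch.tensor([3, 1, 4, 2])

inclusive = torch.cumsum(x, dim=0)   # [3, 4, 8, 10]
exclusive = inclusive - x            # [0, 3, 4, 8]
print(inclusive.tolist(), exclusive.tolist())
```
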
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/gather.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/histogram.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+#             '0::Fwd::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+#             '0::GradX::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+#             '0::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/padded_gather.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
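+
+# Minimal usage sketch (routing metadata built with ops.sort / ops.histogram /
+# ops.round_up / ops.inclusive_cumsum, as in the benchmarks in this repo):
+#   out = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+# This permutes the rows of `x` into per-expert bins padded to the block size.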
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/padded_scatter.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
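+
+
+# Note: padded_scatter un-permutes rows produced by padded_gather back to
+# token order, reducing the top_k copies of each token and applying `weights`
+# when provided; its backward computes the data gradient with
+# kernels.padded_gather.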
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+#         tokens_per_expert = ops.histogram(top_expert, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+#         tokens_per_expert = ops.histogram(top_expert, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/repeat.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
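+
+
+# Example: repeat(x, torch.Size((1, 2))) tiles a 2D `x` twice along its second
+# dimension, while an all-ones tiling returns `x` as-is to avoid a copy.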
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/replicate.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrapped in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/round_up.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
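+
+
+# Worked example: round_up(torch.tensor([3, 130], dtype=torch.int32), 128)
+# returns tensor([128, 256], dtype=torch.int32); each entry is rounded up to
+# the next multiple of `value`.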
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/scatter.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
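+
+
+# scatter is the un-padded counterpart of padded_scatter: rows are returned to
+# token order using `bins` alone, and a gradient for `weights` is computed via
+# kernels.scatter_wgrad when it is required.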
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/sort.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrapped in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
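+
+# Usage sketch: `bin_ids, indices = sort(top_expert)` returns the sorted keys
+# and the permutation that sorts them. `end_bit` may be lowered when the
+# maximum key is known (see sort_benchmark.py) so fewer bits are sorted over.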
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/stk_autocast.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/sum.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
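+
+
+# When the reduced dimension has length 1 the result equals x.squeeze(dim),
+# so the reduction kernel is skipped; otherwise this is a plain torch sum.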
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/ops/topology.py b/build/torch210-cxx11-cu126-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrapped in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
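+
+# The int16 output holds block column indices for the block-sparse (BCSR)
+# topology derived from the padded expert bins; note that stk.Matrix stores
+# its column_indices with the same int16 dtype.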
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/autocast.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
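+# A 2D tensor counts as "transposed" here when it is stored column-major:
+# stride 1 along dim 0 and a dim-1 stride equal to the number of rows.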
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
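+    # In the MoE path the sparse row counts are already padded to multiples of
+    # 128 (ops.round_up), so these checks mainly catch model dimensions that
+    # are not multiples of the Triton block sizes configured above.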
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
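+# Expands the CSR-style `offsets` into one row id per nonzero block: every
+# block owned by block-row `pid` receives the value `pid`.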
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/matrix.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
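A hedged sketch of the three matmul flavours wired above (dense = sparse @ dense, dense = dense @ sparse, sparse = dense @ dense). The sputnik kernels they dispatch to are CUDA kernels, so a CUDA device and fp16/bf16 operands are assumed, following the m/k/n convention of the tests below.

```python
import torch
import stk  # assumption: this build's stk package is importable

m, k, n, blocking = 512, 512, 512, 128
cuda = torch.device("cuda")

# Block-sparse LHS for dsd, built on CPU then moved to the GPU.
a_mask = stk.random.dense_mask(m, k, 0.5, blocking)
a = stk.ops.to_sparse((torch.randn(m, k) * a_mask).half(), blocking).to(cuda)
b = torch.randn(k, n, dtype=torch.float16, device=cuda)

out = stk.ops.dsd(a, b)  # dense [m, n] result of (sparse a) @ (dense b)

# sdd needs an output topology; only its sparsity pattern is consumed.
topo = stk.random.mask(m, n, 0.5, blocking).to(cuda)
c = torch.randn(m, k, dtype=torch.float16, device=cuda)
out_sparse = stk.ops.sdd(c, b, topo)  # stk.Matrix with topo's pattern
```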
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler than its current implementation.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
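A hedged round-trip sketch for `to_sparse`/`to_dense`. Both helpers are plain PyTorch here, so under the same assumptions as the commented-out `MatrixOpsTest` below they should run on CPU; the block size must divide both dimensions.

```python
import torch
import stk  # assumption: this build's stk package is importable

x = torch.randn(16, 32, dtype=torch.float16)
x[:8, :16] = 0.0  # zero two full 8x8 blocks so they are dropped

sparse = stk.ops.to_sparse(x, blocking=8)
dense = stk.ops.to_dense(sparse)

print(sparse.nnz)                  # stored elements: 6 blocks * 8 * 8 = 384
print(torch.equal(x, dense))       # expected: True (lossless round trip)
print(stk.ops.sum(sparse).item())  # sum over stored values only
```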
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/random/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/random/random_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
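A hedged sketch of the three random helpers above. `dense_mask` is pure numpy/torch; `mask` and `randn` additionally go through `to_sparse`, so they return `stk.Matrix` objects.

```python
import stk  # assumption: this build's stk package is importable

# 0/1 float32 mask: 8x16 = 128 blocks, round(128 * 0.25) = 32 blocks kept.
m = stk.random.dense_mask(128, 256, sparsity=0.75, blocking=16)
print(m.shape, int(m.sum()))  # torch.Size([128, 256]), 32 * 16 * 16 = 8192

# Block-sparse Matrix filled with normal samples over the same kind of mask.
x = stk.random.randn((128, 256), sparsity=0.75, blocking=16)
print(x.nnz)                  # 8192 stored values
```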
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py b/build/torch210-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from .. import random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/xpu_fused_moe.py b/build/torch210-cxx11-cu126-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# default
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+ # Convert u4 to s4 to avoid handling the zero point in the GEMM kernel.
+ # Only the default zero point is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+ # TODO: this will all be integrated into the C++ function; temporarily exposed here before GEMM fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
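As an aside, the routing helper above is plain PyTorch and can be exercised in isolation; only the grouped GEMMs and gather/scatter paths need the XPU custom ops. A minimal, hedged sketch, assuming `route_tokens_xpu` has been imported from this module:

```python
import torch
# assumption: route_tokens_xpu has been imported from this module

hidden, num_experts, top_k = 64, 8, 2
x = torch.randn(4, 16, hidden)                 # [batch, seq, hidden]
router_weight = torch.randn(num_experts, hidden) * 0.02

logits, weights, ids = route_tokens_xpu(
    x,
    router_weight,
    router_bias=None,
    moe_top_k=top_k,
    moe_num_experts=num_experts,
    moe_normalize_expert_weights=1,            # L1-normalize the top-k weights
)
print(logits.shape, weights.shape, ids.shape)  # (64, 8), (64, 2), (64, 2)
```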
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
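A hedged sketch of the convenience wrappers exported above, mirroring how the dMoE layer later in this diff uses them (token-to-expert counts, offsets, and a radix sort). The underlying ops are compiled CUDA kernels, so int32 tensors on a CUDA device are assumed.

```python
import torch
import megablocks  # assumption: this build is importable under this name

expert_ids = torch.randint(0, 8, (1024,), dtype=torch.int32, device="cuda")

counts = megablocks.histogram(expert_ids, num_bins=8)          # tokens per expert
starts = megablocks.cumsum(counts, dim=0, exclusive=True)      # first row per expert
sorted_ids, order = megablocks.argsort(expert_ids, end_bit=3)  # 8 experts -> 3 bits
```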
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/activation_fn.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/all_to_all.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/arguments.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in the shared expert (allows using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+ shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
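A hedged construction sketch for `Arguments`. Passing `device` explicitly avoids the `torch.cuda.current_device()` default factory, and `mlp_impl='grouped'` sidesteps the sparse-path triton version check (it does, however, require the grouped_gemm backend to be available).

```python
import torch
# assumption: Arguments is imported from this module

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_type="mlp",
    mlp_impl="grouped",      # __post_init__ checks grouped_gemm availability
    fp16=False,
    bf16=True,
    device=torch.device("cuda"),
)
print(args.shared_expert_hidden_size)  # defaults to ffn_hidden_size (4096)
```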
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/common.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
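A small hedged example of `cast_if_autocast_enabled`: inside an autocast region it re-casts a tensor to the active autocast dtype, and outside it is a no-op (assumes a CUDA device and that the function is imported from this module).

```python
import torch
# assumption: cast_if_autocast_enabled is imported from this module

x = torch.randn(4, 4, device="cuda")  # float32

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = cast_if_autocast_enabled(x)
    print(y.dtype)  # torch.bfloat16 while autocast is active

print(cast_if_autocast_enabled(x).dtype)  # torch.float32 outside autocast
```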
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/dmlp_registry.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+ (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/dmoe.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
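+        # For example (illustrative): with padded_tokens=256, blocking=128 and a
+        # 256-wide expert slice, block_rows=2 and blocks_per_row=2, giving
+        # offsets = [0, 2, 4].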
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
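+        # For example (illustrative): tokens_per_expert=[3, 1, 0, 2] with
+        # blocking=128 pads to [128, 128, 0, 128], so padded_bins=[128, 256, 256, 384].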
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/gelu.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
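+# The magic constants below come from the tanh approximation of GeLU,
+# gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))),
+# with sqrt(2/pi) ~= 0.79788456 and 3 * 0.044715 * sqrt(2/pi) ~= 0.1070322243.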
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/glu.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+                'Memory optimized implementation not yet supported for GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
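+        # (out = (act(x @ w1.T) * (x @ v1.T)) @ w2, with every intermediate
+        # sharing the block-sparse topology `topo`.)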
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/memory_test.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/mlp.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
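+    # Flatten the expert dimension: [experts_per_rank, rows, columns] becomes
+    # [experts_per_rank * rows, columns] to match the stacked block-sparse layout.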
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/moe.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
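+    # For example (illustrative): with 8 experts, moe_loss_weight=0.01,
+    # num_layers=4, 2048 tokens and top_k=2,
+    # scale = (8 * 0.01) / (4 * 2048 * 2) ~= 4.9e-6.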
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
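+        # For example (illustrative): top_k=2, tokens=1024, world_size=1,
+        # num_experts=8 and moe_capacity_factor=1.25 give a capacity of
+        # int(1.25 * 2 * 1024 / 8) = 320 tokens per expert.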
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
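+        # i.e. loss = num_experts / (tokens * top_k)
+        #             * dot(tokens_per_expert, mean(expert_scores, dim=0)),
+        # the usual auxiliary load-balancing loss.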
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/mpu.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
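+
+
+# Worked example (illustrative): with world_size=4, moe_num_experts=2 and
+# ffn_hidden_size=4096, expert_sharding_degree=2, hidden_sharding_degree=2,
+# experts_per_rank=1 and features_per_rank=2048.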
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/router.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
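+
+
+# Illustrative sketch (assumed inputs, not part of the upstream module): for an
+# index tensor with 6 entries and num_experts=4, the uniform assignment is
+#   torch.remainder(torch.arange(6), 4)  ->  [0, 1, 2, 3, 0, 1]
+# i.e. tokens are striped round-robin across the experts regardless of the
+# router scores.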
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
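+
+
+# Shape sketch (illustrative, not part of the upstream module): for an input x
+# of shape [batch, seq, hidden_size], forward() flattens to [batch * seq,
+# hidden_size] and returns
+#   scores          [batch * seq, moe_num_experts]
+#   expert_weights  [batch * seq, moe_top_k]
+#   expert_indices  [batch * seq, moe_top_k]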
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_layers/sharedexpert_registry.py b/build/torch210-cxx11-cu128-aarch64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
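+
+
+# Illustrative usage (assumed Arguments values, not part of the upstream
+# module): get(Arguments(mlp_type='glu')) returns a glu.SharedGLU instance,
+# while an unsupported mlp_type raises ValueError.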
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch210-cxx11-cu128-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..7d9953915c47274e401eeab2f0332618b25769ed
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3f60e88338e68c1def0050d88229517cf28b83852da6b4a3fff41e73331eca0
+size 21088232
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_version.py b/build/torch210-cxx11-cu128-aarch64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/backend/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/backend/kernels.py b/build/torch210-cxx11-cu128-aarch64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub triton autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
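+
+
+# Worked example (illustrative, not part of the upstream module): with two
+# experts, bins=[3, 5] and padded_bins=[4, 8], padded_gather produces an output
+# with padded_bins[-1] == 8 rows; rows 0-2 hold expert 0's three tokens, row 3
+# stays zero as padding, rows 4-5 hold expert 1's two tokens and rows 6-7 stay
+# zero.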
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+    # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
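+
+
+# Note (illustrative, based on the kernel above): the launch grid is
+# (num_experts, expert_capacity), so each expert copies at most expert_capacity
+# of its assigned tokens into its slice of the output; tokens beyond the
+# capacity are simply not gathered, and unused capacity slots remain zero
+# because the output is zero-initialized.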
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+    # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/benchmark_util.py b/build/torch210-cxx11-cu128-aarch64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
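+
+
+# Illustrative usage (assumed model/input, not part of the upstream module):
+#
+#   mean_ms, std_ms = benchmark_function(lambda: model(x), iterations=100)
+#   log_benchmark('dMoE forward', {'batch': x.shape[0]}, mean_ms, std_ms)
+#
+# benchmark_function times with CUDA events, so it requires a GPU and the
+# callable should launch CUDA work.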
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/cpu_fused_moe.py b/build/torch210-cxx11-cu128-aarch64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation processes all experts in parallel using batched operations
+ instead of sequential for loops, which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Build expert mask: which tokens go to which expert
+ # expert_mask[expert_id] contains indices of (token_idx, topk_pos) pairs
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
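+
+
+# Shape sketch (illustrative, assumed sizes, not part of the upstream module):
+# with num_tokens=4, hidden_size=8, inter_size=16, num_experts=2 and topk=2:
+#   hidden_states  [4, 8]
+#   w1             [2, 8, 32]    # gate_up_proj, 2 * inter_size columns
+#   w2             [2, 16, 8]    # down_proj
+#   topk_weights   [4, 2], topk_ids [4, 2]
+# cpu_fused_moe(...) returns a [4, 8] tensor accumulated over the selected experts.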
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/cpu_moe_cpp.py b/build/torch210-cxx11-cu128-aarch64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
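+
+
+# Illustrative call (assumed tensors, not part of the upstream module); shapes
+# follow the docstring above (w1: [E, 2N, K], w2: [E, K, N]):
+#
+#   out = fused_moe_cpp(hidden_states, w1, w2, topk_weights,
+#                       topk_ids.to(torch.int32), w1_bias=b1, w2_bias=b2,
+#                       alpha=1.702, limit=7.0, is_vnni=False)
+#
+# float32 inputs are transparently run in bfloat16 and cast back on return.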
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm/backend.py b/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
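+
+
+# Shape sketch (illustrative, not part of the upstream module), following
+# _allocate_output above:
+#   trans_a=False, trans_b=False: a [tokens, k], b [num_experts, k, n] -> c [tokens, n]
+#   trans_a=False, trans_b=True : a [tokens, k], b [num_experts, n, k] -> c [tokens, n]
+#   trans_a=True                : a [tokens, k], b [tokens, n]         -> c [num_experts, k, n]
+# where batch_sizes is a 1-D tensor of per-expert token counts.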
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm/ops.py b/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
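+
+
+# Illustrative usage (assumed tensors, not part of the upstream module):
+#
+#   # a: [tokens, k], b: [num_experts, k, n], batch_sizes: per-expert token counts
+#   c = gmm(a, b, batch_sizes)   # forward through the autograd wrapper
+#   c.sum().backward()           # gradients flow to a and b via GroupedGemm.backward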
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm_util.py b/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm
+ pass
+ _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.',
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/layers.py b/build/torch210-cxx11-cu128-aarch64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
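+
+
+# Shape sketch (illustrative, not part of the upstream module): for x of shape
+# [batch, seq, hidden_size], route_tokens returns
+#   logits          [batch * seq, moe_num_experts]
+#   expert_weights  [batch * seq, moe_top_k]  (softmax over the selected top-k logits)
+#   expert_indices  [batch * seq, moe_top_k]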
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
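+
+
+# Shape sketch (illustrative, assumed sizes, not part of the upstream module):
+# mlp_forward expects per-expert batched inputs, e.g.
+#   x        [num_experts, capacity, hidden_size]
+#   w1       [num_experts, hidden_size, 2 * ffn_hidden_size]  # interleaved gate/up columns
+#   w1_bias  [num_experts, 2 * ffn_hidden_size]
+#   w2       [num_experts, ffn_hidden_size, hidden_size]
+#   w2_bias  [num_experts, hidden_size]
+# and returns [num_experts, capacity, hidden_size] after the clamped SwiGLU-style activation.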
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
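+        # e.g. with moe_top_k=4 the shared expert contributes 1/5 of the
+        # output and the routed experts contribute 4/5.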
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup],
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
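+# Switch-Transformer-style auxiliary loss for a single layer:
+#
+#   loss = (num_experts / (tokens * top_k)) * sum_e f_e * P_e
+#
+# where f_e is the number of tokens routed to expert e and P_e is the mean
+# router score assigned to expert e over the batch.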
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (count_num_experts,) = tokens_per_expert.size()
+    assert count_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
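+# Illustrative sketch of indices_and_bins (values are assumptions for the
+# example; tie-breaking within an expert depends on the sort kernel):
+# for top_expert = [2, 0, 2, 1] and num_experts = 4,
+#   bin_ids           = [0, 1, 2, 2]   expert id of each sorted token
+#   indices           = [1, 3, 0, 2]   permutation that sorts tokens by expert
+#   tokens_per_expert = [1, 1, 2, 0]
+#   bins              = [1, 2, 4, 4]   inclusive cumsum, one bin end per expert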
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
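+# Worked example (numbers are assumptions): tokens=4096, top_k=4,
+# num_experts=128, world_size=1, moe_capacity_factor=1.0 gives
+# tokens_per_expert = 4 * 4096 / 128 = 128, so each expert receives at most
+# 128 token slots before overflow.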
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
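+    # With hidden sharding, each expert's ffn weights are split across
+    # hidden_sharding_deg ranks, so every token is sent to all shards of its
+    # expert and the partial results are summed afterwards (Step 7).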
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
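+# Minimal single-process sketch of moe_forward (all shapes and values are
+# assumptions for illustration, not defaults of the kernel):
+#
+#   out, expert_weights, router_scores = moe_forward(
+#       x=x,                                  # [sl, bs, hs]
+#       router_weight=router.weight,
+#       router_bias=router.bias,
+#       moe_top_k=4,
+#       moe_num_experts=128,
+#       w1=gate_up_proj, w2=down_proj,
+#       w1_bias=gate_up_proj_bias, w2_bias=down_proj_bias,
+#       forward_fn=forward_once,              # parallel_forward_once for EP
+#       hidden_size=hs,
+#   )
+#
+# out has the same shape as x; router_scores is [num_experts, sl * bs].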
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
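+        # sort_end_bit is the number of bits the radix sort needs to cover
+        # all expert ids, i.e. ceil(log2(num_experts)) with a minimum of 1.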
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
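+    # Typical wiring (illustrative sketch; names are assumptions): create the
+    # shared weights and register them before the first forward pass, e.g.
+    #
+    #   up_w, down_w, up_b, down_b = create_shared_expert_weights(
+    #       hidden_size=hs, shared_expert_hidden_size=4 * hs,
+    #       device=torch.device("cuda"), dtype=torch.bfloat16,
+    #       init_method=torch.nn.init.kaiming_uniform_,
+    #   )
+    #   mlp.set_shared_expert_weights(up_w, down_w, up_b, down_b)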
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/megablocks/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/metadata.json b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..e3e4edf582b7ffb515d0ed32e9fc9c89f125c441
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/all_to_all_benchmark.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/binned_gather.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/binned_scatter.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/cumsum.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/gather.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/histogram.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/padded_gather.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/padded_scatter.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/repeat.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/replicate.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
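+# Note (assumed semantics): `replicate` expands one value per bin into one value
+# per output position, with `bins` holding the inclusive cumulative sum of bin
+# sizes; the backward pass sums gradients back to one value per bin.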
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/round_up.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
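+# Worked example (illustrative):
+#   round_up(torch.tensor([1, 128, 200], dtype=torch.int32), 128)
+# returns tensor([128, 128, 256]), i.e. each entry rounded up to the nearest
+# multiple of `value`.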
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/scatter.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
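+# Note (assumed semantics): `scatter` is the un-padded counterpart of
+# `padded_scatter`; `x` is expected in the expert-sorted layout produced by the
+# gather kernel, and the result is the weighted un-permutation back to the
+# original token order.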
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/sort.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
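+# Example (illustrative, mirroring the benchmark usage in this package):
+#   bin_ids, indices = sort(top_expert)
+# where `top_expert` is an int tensor of expert assignments, `bin_ids` is its
+# sorted copy and `indices` is the permutation that sorts it (used later by the
+# gather/scatter kernels).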
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/stk_autocast.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
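+# Descriptive note: custom_fwd/custom_bwd disable autocast inside the wrapped
+# function; custom_fwd additionally pre-casts eligible CUDA floating-point
+# arguments to the current autocast dtype so the kernels see a uniform dtype.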
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/sum.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/ops/topology.py b/build/torch210-cxx11-cu128-aarch64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
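+# Note (assumed semantics): `topology` materializes the block-sparse layout of the
+# expert computation; given the padded per-expert offsets it writes one int16
+# column index per nonzero block (output_block_rows * output_block_columns entries).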
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/autocast.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/sputnik.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
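+# Naming convention (descriptive note): op names encode output-lhs-rhs layouts,
+# i.e. dsd = dense output of (sparse @ dense), dds = dense output of
+# (dense @ sparse), sdd = sparse output of (dense @ dense).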
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/triton_kernels.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
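+# Worked example (illustrative): with block-row offsets [0, 2, 3, 6] the kernel
+# writes out = [0, 0, 1, 2, 2, 2], i.e. the block-row id of every nonzero block.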
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/matrix.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
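+# Illustrative example (hypothetical shapes): a 256x256 matrix with 128x128
+# blocking and nonzero blocks at block coordinates (0, 0) and (1, 1) is stored as
+#   data:           [2, 128, 128]  (one block of values per nonzero block)
+#   row_indices:    [0, 1]
+#   column_indices: [0, 1]
+#   offsets:        [0, 1, 2]      (nonzero-block prefix sums, one per block row + 1)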
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/linear_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
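+# Usage sketch (illustrative; variable names are assumptions): for dense tensors
+# `a`, `b`, a sparse stk.Matrix `s` and a topology Matrix `topo`,
+#   sdd(a, b, topo)  computes a @ b, materializing only topo's nonzero blocks,
+#   dsd(s, b)        computes s @ b and returns a dense tensor,
+#   dds(a, s)        computes a @ s and returns a dense tensor.
+# All participating dimensions must be divisible by the blocking (see the Triton
+# backend's _validate_matmul_dims).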
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
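+# Illustrative sketch of the expansion (assuming blocking=2): the block-level
+# index pair (row, col) = (1, 2) expands to the element-level pairs
+# (2, 4), (2, 5), (3, 4), (3, 5), i.e. every element covered by that block in
+# the dense matrix.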
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
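+
+# Illustrative round-trip (not executed on import; the sizes and blocking are
+# examples only):
+#
+#   mask = stk.random.dense_mask(256, 256, 0.5, blocking=16)
+#   x = (torch.randn(256, 256) * mask).half()
+#   sparse = to_sparse(x, blocking=16)       # block-sparse Matrix
+#   assert torch.equal(to_dense(sparse), x)  # exact for block-sparse inputs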
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/random/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/random/random_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
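+
+# Illustrative example (not executed on import): dense_mask(8, 16, 0.5, blocking=8)
+# splits the 8x16 matrix into two 8x8 blocks and zeroes out all but
+# round(2 * (1 - 0.5)) = 1 of them, so the returned float32 mask contains only
+# zeros and ones with round(block_rows * block_cols * (1 - sparsity)) * blocking**2
+# nonzeros.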
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py b/build/torch210-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/xpu_fused_moe.py b/build/torch210-cxx11-cu128-aarch64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# default
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+        for x in arr:
+ prefix.append(prefix[-1] + x)
+ return prefix
+
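+    # Illustrative example: exclusive_prefix_sum([2, 3, 1]) == [0, 2, 5, 6], i.e.
+    # the offset of each expert's first token followed by the total token count,
+    # which is the layout the grouped GEMM expects for expert_first_token_offset.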
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
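+
+# Illustrative example: with num_tokens=1000 and num_experts_per_node=8, the
+# first candidate satisfying ceilDiv(1000, b) * 8 <= b is b = 128
+# (ceilDiv(1000, 128) = 8 and 8 * 8 = 64 <= 128), so 128 is returned.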
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+    # Convert packed u4 weights to s4 so the GEMM kernel does not have to
+    # handle a zero point. Only the default zero point (8) is supported for now.
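+    # Illustrative mapping (assuming the default zero point of 8): each unpacked
+    # u4 nibble v in [0, 15] becomes the signed value v - 8 in [-8, 7];
+    # pack_compact below re-encodes each signed value as its 4-bit
+    # two's-complement nibble (sign bit in bit 3, low three bits of the value)
+    # and packs two nibbles back into one byte.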
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ func. Temporarily exposed here before GEMM fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
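+    # The regions named above are packed back-to-back (each rounded up to a
+    # 256-byte boundary) into a single flat uint8 buffer; ws_map records
+    # (size, byte offset) for each region so typed views can be sliced out
+    # after the prologue kernel has filled them in.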
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
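+
+# Illustrative usage of the convenience wrappers (shapes, dtypes and the CUDA
+# device are assumptions; the calls are not executed on import):
+#
+#   x = torch.randint(0, 8, (1024,), dtype=torch.int32, device="cuda")
+#   counts = histogram(x, num_bins=8)                # tokens per bin
+#   offsets = cumsum(counts, dim=0, exclusive=True)  # start offset of each bin
+#   sorted_x, sorted_idx = argsort(x, end_bit=3)     # radix sort on 3 bits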
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/activation_fn.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/all_to_all.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
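+# Autograd wrapper around torch.distributed.all_to_all_single. The backward pass
+# is itself an all-to-all with the output/input split sizes swapped, which routes
+# each gradient shard back to the rank that produced the corresponding input
+# shard.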
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/arguments.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear # class of the fully connected layer in the shared expert (allows using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
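+# Illustrative construction (values are examples only): a dropless MoE with
+# 8 experts, top-2 routing and the grouped (non-sparse) MLP backend:
+#
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_impl='grouped',
+#       fp16=False,
+#       bf16=True,
+#   )
+
+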
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/common.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/dmoe.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
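+        #
+        # Illustrative example (assuming 2 blocks per row of a fully dense
+        # topology): blocks at offsets [0, 1, 2, 3] have (row, column) pairs
+        # (0,0), (0,1), (1,0), (1,1); sorting by column index gives the gather
+        # order [0, 2, 1, 3], and gathering the row indices in that order
+        # yields the transposed matrix's column indices [0, 1, 0, 1].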
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+        # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+        # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/gelu.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
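+# Backward of the tanh-approximated GELU used by F.gelu(..., approximate='tanh').
+# The constants are 0.79788456 ~= sqrt(2/pi) and 0.1070322243 = 3 * 0.044715 *
+# sqrt(2/pi), i.e. the derivative of 0.5 * x * (1 + tanh(sqrt(2/pi) * (x +
+# 0.044715 * x**3))) applied in place to the incoming gradient.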
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
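+
+# The constants above come from the tanh approximation of GeLU,
+#   gelu(x) ~= 0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x**3))),
+# where 0.79788456 ~= sqrt(2/pi) and 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi),
+# so `ff` is the analytic derivative of that expression. A quick,
+# illustrative check against autograd (not part of the library):
+#
+#   x = torch.randn(8, requires_grad=True)
+#   torch.nn.functional.gelu(x, approximate='tanh').sum().backward()
+#   assert torch.allclose(x.grad, _gelu_backward_inplace(torch.ones(8), x.detach()))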
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/glu.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
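+
+ # For reference, the dense equivalent of the block-sparse computation
+ # above is (illustrative only, ignoring routing and topology):
+ #
+ # out = (activation_fn(x @ w1.t()) * (x @ v1.t())) @ w2
+ #
+ # i.e. a gated linear unit where the w1 branch is activated and
+ # multiplied elementwise with the v1 branch before the w2 projection.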
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
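+
+# Usage sketch (shapes are assumptions): for `ne` experts, hidden size `hs`
+# and per-expert FFN width `fs`,
+#
+#   out = memory_optimized_grouped_glu(
+#       x,             # [num_tokens, hs]
+#       w1,            # [ne, fs, hs]
+#       v1,            # [ne, fs, hs]
+#       w2,            # [ne, fs, hs]
+#       batch_sizes,   # [ne], int64 on CPU: tokens per expert
+#       activation_fn,
+#   )
+#
+# The forward saves the layer input and the two first-layer outputs and
+# recomputes the activation in backward, trading compute for memory.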
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/memory_test.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6))
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
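+ # NOTE: the factor of 2 assumes two bytes per element (bf16 weights).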
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6))
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/mlp.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+ # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+ # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # Enable using a weighted sum for the shared expert output,
+ # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
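+
+ # Worked example (hypothetical numbers): with moe_top_k = 3 the shared
+ # expert counts as one extra expert, so t_experts = 4 and the result is
+ # shared_expert_out / 4 + expert_out * (3 / 4), i.e. the shared expert
+ # contributes with weight 1/4 and the combined routed-expert output
+ # with weight 3/4. Without the flag the two outputs are simply summed.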
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/moe.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} tokens_per_expert tensors '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
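+
+# Worked example of the scale (hypothetical values): with
+# moe_num_experts = 8, moe_loss_weight = 0.01, num_layers = 12,
+# tokens = 4096 and moe_top_k = 2,
+#   scale = (8 * 0.01) / (12 * 4096 * 2) ~= 8.14e-7,
+# which multiplies the dot product of per-expert token counts and mean
+# expert scores.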
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
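+
+ # Worked example (hypothetical values): with moe_top_k = 2, 4096 tokens,
+ # an expert-parallel world size of 8, 64 experts and
+ # moe_capacity_factor = 1.25, tokens_per_expert = 2 * 4096 * 8 / 64 = 1024
+ # and the capacity is int(1.25 * 1024) = 1280 tokens per expert.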
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
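+
+ # Illustrative example (assumed inputs): with num_experts = 4 and
+ # top_expert = [2, 0, 2, 1], the sort yields bin_ids = [0, 1, 2, 2] and
+ # indices = [1, 3, 0, 2], the histogram gives tokens_per_expert =
+ # [1, 1, 2, 0], and the inclusive cumsum gives bins = [1, 2, 4, 4],
+ # i.e. the end offset of each expert's contiguous token range.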
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+ #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/mpu.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
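+
+# Worked example (hypothetical configuration): with an expert-parallel
+# world size of 8 and moe_num_experts = 4, expert_sharding_degree = 4,
+# hidden_sharding_degree = 8 // 4 = 2, experts_per_rank = 4 // 4 = 1 and
+# features_per_rank = ffn_hidden_size // 2.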
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/router.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+ # so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
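+
+# Shape sketch (assumed sizes): for an input of shape [sl, bs, hs] with
+# moe_num_experts = E and moe_top_k = k, forward returns
+#   scores:         [sl * bs, E]  softmax over experts,
+#   expert_weights: [sl * bs, k]  top-k routing probabilities,
+#   expert_indices: [sl * bs, k]  selected expert ids per token.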
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch210-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..b730d142f5ea59b45f3e5a9f0e347dbd8b7b589f
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4e977b806a1a10968921e0bae84919664195cdc7baf05c08bf9ee63e4daa752
+size 21009984
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_version.py b/build/torch210-cxx11-cu128-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/backend/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/backend/kernels.py b/build/torch210-cxx11-cu128-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton autotune when testing in an environment that does not have
+# CUDA. This approach preserves the original code but enables testing without
+# a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
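+# Illustrative note (a sketch, not part of the kernel API): with 2 experts,
+# tokens_per_expert = [3, 2] and block-size padding to multiples of 4, the
+# expected inputs would be bins = [3, 5] (inclusive cumsum of the token counts)
+# and padded_bins = [4, 8] (inclusive cumsum of the padded counts), so the
+# gathered output has padded_bins[-1] = 8 rows, with zero rows as padding.
+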
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per entry of 'wgrad' (one token/expert pair). Array 'x'
+    # has a greater or equal number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+    # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+    # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/benchmark_util.py b/build/torch210-cxx11-cu128-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
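+
+
+# Illustrative usage (a sketch, not executed on import), assuming a CUDA device:
+#
+#   a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
+#   mean_ms, std_ms = benchmark_function(lambda: a @ a)
+#   log_benchmark("matmul", {"shape": tuple(a.shape)}, mean_ms, std_ms)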
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/cpu_fused_moe.py b/build/torch210-cxx11-cu128-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
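+# Illustrative check (a sketch, not executed on import) of the formula above:
+#
+#   gate, up = torch.randn(4, 8), torch.randn(4, 8)
+#   out = swigluoai_activation(gate, up)
+#   g, u = gate.clamp(max=7.0), up.clamp(-7.0, 7.0)
+#   assert torch.allclose(out, (u + 1) * (g * torch.sigmoid(g * 1.702)))
+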
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
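+# Illustrative usage (a sketch, assuming 4 experts, hidden size 8, top-2 routing):
+#
+#   x = torch.randn(3, 8)                      # 3 tokens
+#   router_weight = torch.randn(4, 8)          # [num_experts, hidden]
+#   logits, weights, idx = route_tokens_cpu(x, router_weight, None, 2, 4)
+#   # logits: [3, 4]; weights and idx: [3, 2]; each row of weights sums to 1.
+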
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over experts, but processes all tokens routed to
+    an expert with vectorized PyTorch operations, which is efficient enough on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # Process one expert at a time: for each expert, find the (token_idx,
+    # topk_pos) pairs routed to it and run that expert's MLP on those tokens.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
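+# Layout note (illustrative, assuming inter_size = 2): with the interleaved
+# GptOss layout a gate_up row is [g0, u0, g1, u1], so
+#
+#   row = torch.tensor([1.0, 10.0, 2.0, 20.0])
+#   gate, up = row[::2], row[1::2]        # [1., 2.] and [10., 20.]
+#
+# whereas the standard layout [g0, g1, u0, u1] is split at inter_size.
+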
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/cpu_moe_cpp.py b/build/torch210-cxx11-cu128-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py b/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# TODO(tgale): Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+#
+# The backend operations are imported from the megablocks package itself,
+# since grouped_gemm is vendored into megablocks in this repository.
+from .._ops import ops as backend  # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py b/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
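+
+
+# Illustrative usage (a sketch, assuming trans_b=False and a CUDA device):
+# three expert groups with 2, 3 and 1 rows share a single grouped GEMM.
+#
+#   a = torch.randn(6, 16, device="cuda", dtype=torch.bfloat16)      # sum(batch_sizes) x k
+#   b = torch.randn(3, 16, 32, device="cuda", dtype=torch.bfloat16)  # num_groups x k x n
+#   batch_sizes = torch.tensor([2, 3, 1])                            # per-group row counts
+#   c = gmm(a, b, batch_sizes)                                       # 6 x 32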
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm_util.py b/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored into this package, so no external import is
+    # required here; simply mark it as available.
+    _grouped_gemm_is_available = True
+except ImportError as error:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+    )
+ assert _grouped_gemm_is_available, msg
+
+
+# Expose the vendored grouped_gemm modules under the names used elsewhere in
+# megablocks: `backend` holds the raw kernels and `ops` the autograd wrappers.
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/layers.py b/build/torch210-cxx11-cu128-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree from the world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
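+# Worked example (illustrative): with world_size = 8, moe_num_experts = 4 and
+# ffn_hidden_size = 1536, expert_sharding_degree = min(8, 4) = 4, so
+# hidden_sharding_degree = 8 // 4 = 2, experts_per_rank = 4 // 4 = 1 and
+# features_per_rank = 1536 // 2 = 768.
+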
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
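+# Worked example (illustrative): with moe_top_k = 3 and
+# shared_expert_weighted_sum = True, the shared expert contributes 1/4 of the
+# combined output and the routed experts 3/4; otherwise the two outputs are
+# simply added.
+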
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
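+# Worked example (illustrative): 1024 tokens, top_k = 4, 128 experts, a single
+# expert-parallel rank and moe_capacity_factor = 1.0 give
+# tokens_per_expert = 4 * 1024 * 1 / 128 = 32, i.e. 32 slots per expert.
+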
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
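+# Worked example (illustrative): with 2 experts, top_k = 1, 4 tokens routed as
+# tokens_per_expert = [3, 1] and mean expert_scores = [0.75, 0.25], the loss is
+# (2 / (4 * 1)) * (3 * 0.75 + 1 * 0.25) = 0.5 * 2.5 = 1.25, while a balanced
+# router ([2, 2] with scores [0.5, 0.5]) gives 0.5 * 2.0 = 1.0.
+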
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
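+# Illustrative example (assuming 3 experts): for top_expert = [2, 0, 2, 1],
+# ops.histogram yields tokens_per_expert = [1, 1, 2], ops.inclusive_cumsum
+# yields bins = [1, 2, 4], bin_ids = [0, 1, 2, 2] (expert ids in sorted order)
+# and indices holds the token positions sorted by expert, e.g. [1, 3, 0, 2].
+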
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Kick off the token-count exchange asynchronously; we wait on the
+        # handle below before using the counts.
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
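+
+# Illustrative call (argument names and shapes here are assumptions, not a
+# documented API): for hidden size hs and ne experts with a fused gate/up
+# projection,
+#
+#   out, expert_weights, router_scores = moe_forward(
+#       x=hidden_states,                      # [sl, bs, hs]
+#       router_weight=router.weight,
+#       router_bias=router.bias,
+#       moe_top_k=4,
+#       moe_num_experts=ne,
+#       w1=gate_up_proj,                      # [ne, hs, 2 * ffn]
+#       w2=down_proj,                         # [ne, ffn, hs]
+#       w1_bias=gate_up_proj_bias,
+#       w2_bias=down_proj_bias,
+#       sort_end_bit=max(int(math.ceil(math.log2(ne))), 1),
+#       forward_fn=forward_once,
+#       hidden_size=hs,
+#   )
+#
+# returns the combined expert output in the input shape, the per-token expert
+# weights, and the dense router scores.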
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
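+
+# Minimal sketch of how the helper above might be used (init_method is any
+# in-place initializer; the sizes are made up):
+#
+#   up_w, down_w, up_b, down_b = create_shared_expert_weights(
+#       hidden_size=1152,
+#       shared_expert_hidden_size=3072,
+#       device=torch.device("cuda"),
+#       dtype=torch.bfloat16,
+#       init_method=torch.nn.init.kaiming_uniform_,
+#   )
+#
+# up_b and down_b come back as None; callers that want biases can allocate
+# their own.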
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
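+
+
+# Illustrative wiring of the shared-expert variant (the router/experts setup
+# mirrors MegaBlocksMoeMLP and is omitted; shapes follow
+# create_shared_expert_weights above):
+#
+#   mlp = MegaBlocksMoeMLPWithSharedExpert()
+#   ...  # attach mlp.router and mlp.experts as usual
+#   mlp.set_shared_expert_weights(
+#       up_proj_weight=shared_up,      # [shared_ffn, hs]
+#       down_proj_weight=shared_down,  # [hs, shared_ffn]
+#       weighted_sum=False,
+#       activation_fn=torch.nn.functional.silu,  # assumed activation
+#   )
+#   out, expert_weights = mlp(hidden_states)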
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/megablocks/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is: after adding it to `sys.modules`,
+    # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..e3e4edf582b7ffb515d0ed32e9fc9c89f125c441
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per element (fp16).
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/binned_gather.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/binned_scatter.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/cumsum.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
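+
+
+# Example (illustrative, assuming an int32 tensor on the GPU):
+# inclusive_cumsum(torch.tensor([3, 1, 4], dtype=torch.int32, device="cuda"), 0)
+# gives tensor([3, 4, 8]), while exclusive_cumsum on the same input gives
+# tensor([0, 3, 4]). These running totals are the "bins" used to delimit each
+# expert's tokens.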
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/gather.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/histogram.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1), which calls
+# torch.as_strided(...). Circumvent this chain to avoid the overhead it adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/padded_gather.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/padded_scatter.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/repeat.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/replicate.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/round_up.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
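+
+
+# Worked example: round_up(torch.tensor([3, 128, 129], dtype=torch.int32), 128)
+# yields [128, 128, 256] -- each count is rounded up to the next multiple of
+# the block size.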
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/scatter.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/sort.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
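+# Returns both the sorted values and the permutation ("iota") that produced
+# them; `end_bit` limits how many low-order bits participate in the sort.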
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/stk_autocast.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/sum.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
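+# Sum over `dim`; when that dim has length one, squeezing gives the same
+# result without a reduction.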
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/ops/topology.py b/build/torch210-cxx11-cu128-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
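+# Allocates one int16 entry per output block (rows x columns) and fills it
+# with the block-sparse topology metadata computed by the `indices` kernel.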
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/autocast.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
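+# Gradient helpers: the gradient of out = lhs @ rhs w.r.t. either operand is
+# expressed as another matmul of dy with the opposite operand, with operand
+# transposes folded into the choice of op and argument order.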
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
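+# For each block row, write that row's index once per nonzero block, giving an
+# explicit per-block row id to accompany the CSR-style offsets.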
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/matrix.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
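+    # Build the transposed matrix's metadata (column indices, offsets and a
+    # block permutation) without materializing the transposed data.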
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
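+# The three-letter names describe (output, lhs, rhs) with 'd' dense and
+# 's' block-sparse:
+#   dsd: dense  = sparse @ dense
+#   dds: dense  = dense  @ sparse
+#   sdd: sparse = dense  @ dense, with the sparsity pattern taken from `topo`.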
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
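+
+# Expands block-level (row, col) index pairs into the element-level indices
+# covered by each `blocking` x `blocking` block.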
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/random/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/random/random_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
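+# Build a dense 0/1 mask where a uniformly random subset of the
+# `blocking` x `blocking` blocks is zeroed so that roughly `sparsity` of the
+# blocks are dropped.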
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py b/build/torch210-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from .. import random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/xpu_fused_moe.py b/build/torch210-cxx11-cu128-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM wrapper (unquantized weights, no scales).
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
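+    """Grouped GEMM over experts with unquantized weights.
+
+    expert_token_count[i] is the number of rows of input_A assigned to expert i;
+    its exclusive prefix sum becomes the per-expert row offsets handed to the kernel.
+    """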
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
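+    """Grouped GEMM wrapper that also supports int4 / mxfp4 quantized weights.
+
+    num_rows_per_expert is a device tensor; its cumulative sum (prefixed with 0)
+    is passed to the kernel as expert_first_token_offset.
+    """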
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
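+    """Return the smallest candidate block size for which the total number of
+    (sequence block, expert) pairs fits within that block size; defaults to 1024."""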
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
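+    # Pack two signed 4-bit values into one byte: bit 3 of each nibble holds the
+    # sign flag and bits 0-2 hold the low three bits of the two's-complement byte.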
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ op. Temporarily exposed here before GEMM fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
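+        """Reserve a 256-byte-aligned slice of the uint8 workspace and record its (size, offset)."""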
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/activation_fn.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/all_to_all.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
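+    """Autograd-aware wrapper around dist.all_to_all_single.
+
+    The backward pass runs the reverse all-to-all with the output/input split sizes swapped.
+    """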
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/arguments.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+                import triton
+                from packaging import version
+
+                # String comparison misorders versions like '3.10.0' vs '3.2.0'.
+                if version.parse(triton.__version__) >= version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/common.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
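+    """Cast `tensor` to the active autocast dtype for its device; returns it unchanged when autocast is off."""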
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/dmlp_registry.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/dmoe.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
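+    """Promote a 0-d tensor to 1-d so it can be concatenated and indexed downstream."""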
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/gelu.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/glu.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Activation and GLU gating.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for the shared expert.
+
+    Note: this is a copy -> paste -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/memory_test.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
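+    """Collect the contiguous tensors currently tracked by the garbage collector, de-duplicated by data pointer."""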
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6))
+
+ # Calculate weight and gradient memory usage.
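+    #
+    # NOTE: The factor of 2 below is the number of bytes per element for the
+    # bf16 parameters (and their gradients) created above.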
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6))
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/mlp.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+    # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+    # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+            # Enable using a weighted sum for the shared expert output,
+            # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
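+            # e.g. with moe_top_k = 3 the shared expert contributes 1/4 of the
+            # combined output and the routed experts contribute 3/4.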
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/moe.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
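+        # Illustrative example: top_k=2, tokens=4096, world_size=8 and
+        # num_experts=64 gives 1024 token slots per expert before the
+        # capacity factor is applied.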
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
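+        # The usual auxiliary load-balancing loss: the dot product of the
+        # per-expert token counts with the mean router probabilities, scaled
+        # so that a perfectly uniform assignment yields 1.0.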
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+    # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/mpu.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
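+    # Illustrative example: 4 experts on an 8-way expert parallel group gives
+    # esd = 4, so each expert's ffn_hidden_size is split hsd = 2 ways.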
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/router.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operation.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
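+        # NOTE: No backward is defined; the returned expert indices are
+        # integral and non-differentiable, so only the forward routing changes.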
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
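+        # Multiplicative jitter: draw a uniform sample from [1 - eps, 1 + eps]
+        # for every element and scale the input by it during training.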
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_layers/sharedexpert_registry.py b/build/torch210-cxx11-cu130-aarch64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch210-cxx11-cu130-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..3aace39adc87c5eb54f3e3b57f2d05bf12eb3eb6
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a97da44105e24c5ce37c013d124fa87ddb71aa465a5278304f87f426bafd575
+size 12073200
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_version.py b/build/torch210-cxx11-cu130-aarch64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/backend/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/backend/kernels.py b/build/torch210-cxx11-cu130-aarch64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have
+# CUDA. This approach preserves the original code but enables testing without
+# a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
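+#
+# 'bins' holds the inclusive cumulative sum of tokens routed to each expert;
+# 'padded_bins' holds the same cumulative sum after each expert's count has
+# been padded (e.g. rounded up to a block-size multiple), so 'b' can have
+# more rows than 'a'.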
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
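+        # No padding here, so 'bins' doubles as the 'padded_bins' argument.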
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
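+#
+# Computes the gradient w.r.t. the per-token expert weights as the dot product
+# of each routed activation row with the corresponding incoming gradient row.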
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
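+
+# Round-trip sketch for binned_gather / binned_scatter (illustrative sizes; a
+# CUDA device is assumed since these launch Triton kernels). Suppose 4 tokens,
+# 2 experts, top_k=1, with tokens 0,2 routed to expert 0 and tokens 1,3 to
+# expert 1. With rows sorted by expert:
+#
+#   indices = torch.tensor([0, 2, 1, 3], dtype=torch.int32, device="cuda")
+#   bins = torch.tensor([2, 4], dtype=torch.int32, device="cuda")
+#   x = torch.randn(4, 8, device="cuda")
+#   gathered = binned_gather(x, indices, None, bins, expert_capacity=2, top_k=1)
+#   # gathered: (2, 2, 8), one row per (expert, capacity slot)
+#   restored = binned_scatter(gathered, indices, None, bins, top_k=1)
+#   # restored: (4, 8), each row back at its original token position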
+
+
+# x: (num_experts, expert_capacity, num_columns), real.
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens * top_k), real.
+# indices: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/benchmark_util.py b/build/torch210-cxx11-cu130-aarch64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+    print(f'mean time = {time:.3f}ms, std time = {std:.3f}ms')
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+    for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
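+
+
+# Minimal usage sketch (not part of the utilities above): it assumes a CUDA
+# device and times a single matmul with the helpers in this file. The "MatMul"
+# label and the shapes are illustrative only.
+if __name__ == "__main__":
+    if torch.cuda.is_available():
+        a = torch.randn(4096, 4096, device="cuda")
+        b = torch.randn(4096, 4096, device="cuda")
+        mean_ms, std_ms = benchmark_function(lambda: a @ b)
+        log_benchmark("MatMul", {"m": 4096, "n": 4096, "k": 4096}, mean_ms, std_ms)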
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/cpu_fused_moe.py b/build/torch210-cxx11-cu130-aarch64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
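+
+# Quick numeric sanity check of the formula above (values are arbitrary, not
+# taken from any model): zero inputs produce zero, and large positive values
+# are clamped to `limit` before the sigmoid gate is applied.
+#
+#   swigluoai_activation(torch.tensor([0.0, 100.0]), torch.tensor([0.0, 100.0]))
+#   # ≈ tensor([0., 56.])  since gate and up clamp to 7, giving (7 + 1) * 7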
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over the experts, but processes all tokens routed
+    to a given expert with a single batched matmul, which keeps the Python
+    overhead low and is reasonably efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # For each expert, find the (token_idx, topk_pos) pairs routed to it and
+    # process those tokens together with batched matmuls.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
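+
+# End-to-end sketch of the two helpers above (all sizes and the random weights
+# are illustrative, not a real checkpoint):
+#
+#   tokens, hidden, inter, experts, topk = 4, 64, 128, 8, 2
+#   x = torch.randn(tokens, hidden)
+#   router_w = torch.randn(experts, hidden)
+#   w1 = torch.randn(experts, hidden, 2 * inter)
+#   w2 = torch.randn(experts, inter, hidden)
+#   _, weights, ids = route_tokens_cpu(x, router_w, None, topk, experts)
+#   y = cpu_fused_moe(x, w1, w2, weights, ids, is_interleaved=True)
+#   assert y.shape == (tokens, hidden)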
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/cpu_moe_cpp.py b/build/torch210-cxx11-cu130-aarch64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm/backend.py b/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
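+
+# Shape sketch for gmm (illustrative sizes; the compiled backend op and, in
+# this build, a CUDA device are assumed):
+#
+#   a: (sum(batch_sizes), k)   tokens grouped by expert
+#   b: (num_experts, k, n)     one weight matrix per expert (trans_b=False)
+#   c: (sum(batch_sizes), n)   outputs in the same grouping as `a`
+#
+#   batch_sizes = torch.tensor([3, 1, 4])  # tokens per expert
+#   a = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)
+#   b = torch.randn(3, 16, 32, device="cuda", dtype=torch.bfloat16)
+#   c = gmm(a, b, batch_sizes)             # -> (8, 32)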
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm/ops.py b/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm_util.py b/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm
+ pass
+ _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/layers.py b/build/torch210-cxx11-cu130-aarch64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+            # Meta implementation - output is (num_experts, bin_size, hidden)
+            if x.dim() >= 2:
+                hidden_size = x.size(-1)
+                return torch.empty(
+                    (bins.numel(), bin_size, hidden_size),
+                    dtype=x.dtype,
+                    device=x.device,
+                )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+            # Meta implementation - the scatter reduces to (tokens, hidden)
+            if x.dim() >= 3:
+                tokens = indices.numel() // top_k if top_k > 0 else x.size(1)
+                return torch.empty(
+                    (tokens, x.size(2)), dtype=x.dtype, device=x.device
+                )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
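+
+# Shape sketch for mlp_forward (illustrative; the interleaved gate/up layout
+# matches the [..., ::2] / [..., 1::2] split above):
+#
+#   x:      (num_experts, expert_capacity, hidden_size)
+#   w1:     (num_experts, hidden_size, 2 * ffn_size), w1_bias: (num_experts, 2 * ffn_size)
+#   w2:     (num_experts, ffn_size, hidden_size),     w2_bias: (num_experts, hidden_size)
+#   output: (num_experts, expert_capacity, hidden_size)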
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
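+
+# Worked example of the weighted-sum branch (numbers are illustrative): with
+# moe_top_k=3 the shared expert contributes 1/4 of the output and the routed
+# experts 3/4, i.e. out = 0.25 * shared_expert_out + 0.75 * expert_out.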
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
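+
+# Worked example (illustrative numbers): tokens=4096, top_k=4, num_experts=128,
+# a single rank and capacity_factor=1.0 give 4 * 4096 * 1 / 128 = 128 token
+# slots per expert.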
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, num_scored_experts = expert_scores.size()
+    assert num_scored_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (num_counted_experts,) = tokens_per_expert.size()
+    assert num_counted_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
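+
+# Sanity check (illustrative): with perfectly uniform routing every expert sees
+# tokens * top_k / num_experts tokens and expert_scores.mean(dim=0) is
+# 1 / num_experts, so the dot product is tokens * top_k / num_experts and the
+# scaled loss is exactly 1.0.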
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Kick off the asynchronous all-to-all for the token counts.
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
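+    """Run ``moe_forward`` and optionally add a shared (always-active) expert.
+
+    When ``shared_up_proj_weight`` and ``shared_down_proj_weight`` are provided,
+    the shared expert MLP is applied to the raw input and merged with the routed
+    expert output via ``combine_expert_shared_outputs``; otherwise this is
+    equivalent to ``moe_forward``. Returns ``(output, expert_weights, router_scores)``.
+    """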
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
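+    """Allocate and initialize the up/down projection weights for a shared expert.
+
+    The up projection has shape ``(shared_expert_hidden_size, hidden_size)`` and
+    the down projection ``(hidden_size, shared_expert_hidden_size)``. Biases are
+    not created, so the last two elements of the returned tuple are always ``None``.
+    """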
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
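+#
+# A minimal, plain-Python illustration of the closure-cell access used below
+# (the names here are illustrative only):
+#
+#   def make_hook(device_mesh):
+#       def hook(module, args):   # closes over `device_mesh`
+#           return args
+#       return hook
+#
+#   h = make_hook(mesh)
+#   h.__closure__[h.__code__.co_freevars.index("device_mesh")].cell_contents  # -> mesh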
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
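+    """Mixture-of-experts MLP block driven by attributes attached by the caller.
+
+    Expects ``self.router`` (a linear layer with ``weight``/``bias`` and optionally
+    ``top_k``) and ``self.experts`` providing ``gate_up_proj``, ``gate_up_proj_bias``,
+    ``down_proj``, ``down_proj_bias`` and ``hidden_size`` (plus optional tuning
+    attributes such as ``num_experts``, ``alpha`` or ``capacity_factor``).
+    ``forward`` returns ``(output, expert_weights)``.
+    """
+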
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
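+    """MoE MLP with an optional shared expert applied to every token.
+
+    Configure the shared expert via ``set_shared_expert_weights``; when no shared
+    weights are set, the layer behaves exactly like ``MegaBlocksMoeMLP``.
+    """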
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/megablocks/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is: once it is added to `sys.modules`,
+    # it would also be picked up by other imports. Instead, derive a unique
+    # module name from the hex-encoded hash of the file path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
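+# Load the `__init__.py` from the parent build directory under a unique module
+# name and re-export its symbols here, so importing this `megablocks` package
+# exposes the same public API as the build variant's top-level module.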
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/metadata.json b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a9813b81c6c98110d265c184f2016d728202289b
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/all_to_all_benchmark.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/binned_gather.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/binned_scatter.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/cumsum.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/gather.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/histogram.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/padded_gather.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/padded_scatter.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/repeat.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/replicate.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/round_up.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
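+    # Round each entry of `x` up to the nearest multiple of `value`, e.g.
+    # round_up(torch.tensor([5, 130], dtype=torch.int32), 128) -> tensor([128, 256]).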
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/scatter.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/sort.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
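+        # The second output receives the permutation (argsort indices) that sorts `x`.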
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/stk_autocast.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/sum.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
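+    # Summing over a singleton dimension is equivalent to dropping it, so skip
+    # the reduction and just squeeze.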
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/ops/topology.py b/build/torch210-cxx11-cu130-aarch64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/autocast.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd forward so that, under autocast, eligible tensor
+    arguments are cast to the autocast dtype and autocast is disabled inside."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
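+
+
+# Example sketch (illustrative, untested): the decorators mirror
+# torch.cuda.amp.custom_fwd/custom_bwd, casting eligible CUDA float tensors to
+# the active autocast dtype in the forward and disabling autocast in the
+# backward.
+#
+#   class Scale(torch.autograd.Function):
+#
+#       @staticmethod
+#       @custom_fwd
+#       def forward(ctx, x):
+#           return 2 * x
+#
+#       @staticmethod
+#       @custom_bwd
+#       def backward(ctx, grad):
+#           return 2 * grad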
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/sputnik.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
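+
+
+# Example sketch (illustrative, untested): `dsd`, `dds`, and `sdd` take the
+# unpacked fields of an stk Matrix plus the dense operand(s); the wrappers in
+# stk.ops.linear_ops are the intended entry points, e.g.
+#
+#   y = dsd(m.size(), m.data, m.offsets, m.row_indices, m.column_indices,
+#           m.offsets_t, m.column_indices_t, m.block_offsets_t,
+#           not m.is_contiguous(), x)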
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/triton_kernels.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: dimension has length {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/matrix.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# TODO:
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers.
+# 3. Make indentation consistent.
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D data.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
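+
+# Worked example (illustrative): for a 2x2 block grid with nonzero blocks at
+# (row, col) = (0, 1) and (1, 0), row_indices = [0, 1] and
+# column_indices = [1, 0]; argsort gives gather_indices = [1, 0], so
+# column_indices_t = [1, 0], block_offsets_t = [1, 0], and
+# offsets_t = [0, 1, 2] (one block per transposed row).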
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
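+
+
+# Example sketch (illustrative, untested): a compatible operand can be built
+# directly from a's topology.
+#
+#   b = Matrix(a.size(), torch.ones_like(a.data), a.row_indices,
+#              a.column_indices, a.offsets)
+#   c = mul(a, b)  # c.data equals a.data; same topology as a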
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/linear_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
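+
+
+# Example sketch (illustrative, untested): dsd/dds mix one sparse and one
+# dense operand and return a dense tensor; sdd multiplies two dense operands
+# and keeps only the blocks selected by `topo`.
+#
+#   y = dsd(a_sparse, b_dense)              # dense output
+#   y = dds(a_dense, b_sparse)              # dense output
+#   y = sdd(a_dense, b_dense, topo_sparse)  # Matrix with topo's layout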
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
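+
+
+# Example round trip (illustrative, untested): a blocked dense matrix that is
+# zero outside its nonzero blocks survives to_sparse/to_dense unchanged.
+#
+#   mask = torch.ones(256, 256); mask[:128, 128:] = 0
+#   x = (torch.randn(256, 256) * mask).half().cuda()
+#   m = to_sparse(x, blocking=128)
+#   assert torch.equal(to_dense(m), x)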
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/random/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/random/random_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
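+
+
+# Example sketch (illustrative, untested):
+#
+#   m = dense_mask(256, 256, sparsity=0.5, blocking=128)  # dense 0/1 float mask
+#   s = mask(256, 256, 0.5, blocking=128)                 # block-sparse Matrix of ones
+#   r = randn((256, 256), 0.5, blocking=128)              # block-sparse Matrix of randoms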
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py b/build/torch210-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/xpu_fused_moe.py b/build/torch210-cxx11-cu130-aarch64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM path (no quantization scales).
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
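+
+
+# Worked example (illustrative): with num_tokens=1000 and 8 experts per node,
+# 32 tokens/block gives 32 blocks (32 * 8 = 256 > 32), 64 gives 16 blocks
+# (128 > 64), and 128 gives 8 blocks (64 <= 128), so 128 is returned.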
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
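+
+
+# Example sketch (illustrative, untested):
+#
+#   buf = torch.zeros(64, dtype=torch.uint8)
+#   offsets = _bytes_to_typed_tensor(buf, torch.int64)  # 8 zeros of dtype int64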
+
+
+def implement_zp(qweight):
+    # Convert u4 to s4 so the GEMM kernel does not need to handle a zero point.
+    # Only the default zero point (8) is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
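+
+
+# Worked example (illustrative): each input byte packs two u4 nibbles;
+# subtracting 8 recentres them to s4, and pack_compact re-encodes each s4
+# value as a two's-complement nibble (sign bit plus the low 3 bits), so the
+# byte 0x00 (u4 values 0, 0) becomes s4 values -8, -8 and is repacked as 0x88.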
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ function. Temporarily exposed before GEMM fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
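+    # Illustrative note (not part of the original source): each region is padded to a
+    # multiple of 256 bytes, e.g. config_ws("a", 100) reserves 256 bytes at offset 0,
+    # so the next region starts at offset 256.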
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
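+
+# Illustrative call sketch (assumptions: an XPU device and the compiled ops are
+# available; shapes follow the docstring above):
+#   hidden_states: [num_rows, hidden_size], w13: [E, 2*inter_size, hidden_size],
+#   w2: [E, hidden_size, inter_size], topk_weights/topk_ids: [num_rows, top_k]
+#   out = xpu_fused_moe(hidden_states, w13, None, None, w2, None, None,
+#                       topk_weights, topk_ids, n_experts_per_token=top_k,
+#                       activation="silu", num_experts=E)
+#   out.shape == hidden_states.shape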
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
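+
+# Shape sketch (illustrative): for x of shape [tokens, hidden] and router_weight of
+# shape [num_experts, hidden], logits is [tokens, num_experts] while expert_weights
+# and expert_indices are both [tokens, moe_top_k].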
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
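+
+# Example (illustrative, assuming the kernel follows standard exclusive-cumsum
+# semantics):
+#   x = torch.tensor([1, 2, 3]); out = torch.empty_like(x)
+#   exclusive_cumsum(x, 0, out)  # out -> tensor([0, 1, 3])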
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
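+
+# Example (illustrative, assuming bins correspond to the integer values
+# 0..num_bins-1, as when histogramming expert ids):
+#   histogram(torch.tensor([0, 1, 1, 3]), num_bins=4)  # -> counts [1, 2, 0, 1]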
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
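+
+# Example (illustrative, assuming standard radix-sort semantics): sorting expert
+# ids [3, 1, 2] with end_bit=2 gives x_out = [1, 2, 3] and iota_out = [1, 2, 0]
+# (the original positions of the sorted values).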
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/activation_fn.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/all_to_all.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/arguments.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
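+# Example (illustrative): a dropless-MoE configuration using the grouped backend.
+#   args = Arguments(hidden_size=1024, ffn_hidden_size=4096, moe_num_experts=8,
+#                    moe_top_k=2, mlp_type='glu', mlp_impl='grouped',
+#                    fp16=False, bf16=True)
+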
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/common.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/dmlp_registry.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e. only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
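+
+# Example (illustrative):
+#   args = Arguments(mlp_type='glu', mlp_impl='grouped')
+#   mlp = get(args)  # returns a GroupedGLU instance per the registry above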
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/dmoe.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/gelu.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
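+    # ff below is d/dx of the tanh-approximated GeLU 0.5*x*(1 + tanh(inner)) with
+    # inner = 0.79788456*x*(1 + 0.044715*x*x); note 0.1070322243 ≈ 3 * 0.044715 * 0.79788456.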
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/glu.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/memory_test.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/mlp.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
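+# Usage note (a sketch, not additional API): SparseMLP.forward below invokes
+# this as memory_optimized_mlp(x, w1, w2, topo, activation_fn), where `topo`
+# is the stk.Matrix block-sparse topology supplied by the caller. The backward
+# pass recomputes the activation output rather than saving it, trading extra
+# FLOPs for lower peak memory.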
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+        # this rank. If the master weights are created first, the PyTorch
+        # caching allocator appears to use the same memory block for these
+        # and the slice, which causes large increases in our peak memory
+        # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
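+# Usage note (a sketch, not additional API): GroupedMLP.forward below invokes
+# this as memory_optimized_grouped_mlp(x, w1, w2, batch_sizes, activation_fn),
+# where `batch_sizes` is a CPU int64 tensor of tokens per expert. As in the
+# sparse variant, the activation output is recomputed in the backward pass to
+# save memory.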
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+    Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+            # Enable using a weighted sum for the shared expert output,
+            # weighted by the number of experts used.
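+            # e.g. with moe_top_k=4 (an illustrative value) the shared expert
+            # output is weighted by 1/5 and the routed expert output by 4/5.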
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/moe.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
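+    #
+    # Illustrative example (hypothetical values): with moe_loss_weight=0.01,
+    # moe_num_experts=8, num_layers=12, tokens=4096 and moe_top_k=2, the scale
+    # is 0.01 * 8 / (12 * 4096 * 2), roughly 8.1e-7.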
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped so that the weight all-gathers can be scheduled *before* the
+# expert model parallel all-to-all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
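+        # Illustrative example (hypothetical values): with top_k=2, tokens=1024,
+        # world_size=4, num_experts=64 and moe_capacity_factor=1.25, this
+        # returns int(1.25 * 2 * 1024 * 4 / 64) = 160 slots per expert.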
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
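+        #
+        # In short (an illustrative summary of the steps above, not extra
+        # behavior): local gather -> all_to_all -> local re-sort by expert ->
+        # expert MLP -> all_to_all back -> hidden-shard reduce -> local scatter.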
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+            # If we're sharding the experts along the hidden dimension,
+            # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+        # If we're sharding the experts along the hidden dimension,
+        # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+        # Un-permute locally to set up for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/mpu.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
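+
+
+# Illustrative sharding example (hypothetical configuration): with
+# moe_num_experts=8, ffn_hidden_size=4096 and an expert-parallel world size of
+# 16, expert_sharding_degree(args) == 8, hidden_sharding_degree(args) == 2,
+# experts_per_rank(args) == 1 and features_per_rank(args) == 2048.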
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/router.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
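+        # e.g. with moe_jitter_eps=0.1 (an illustrative value) each activation
+        # is scaled by a factor drawn uniformly from [0.9, 1.1).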
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch210-cxx11-cu130-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..ae40d277669f147909c87b24ab352118b0c55653
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c3f4b9db1caad794b2dfa9befd5d7225e1a0a78bd891f82bb1d1d84c46143ddf
+size 12041592
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_version.py b/build/torch210-cxx11-cu130-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/backend/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/backend/kernels.py b/build/torch210-cxx11-cu130-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have
+# CUDA. This preserves the original code while enabling testing without a GPU.
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per (token, top_k) entry. The padded array 'x' has a
+    # greater or equal number of rows.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/benchmark_util.py b/build/torch210-cxx11-cu130-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
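+
+
+# Hedged usage sketch (the `model` and `batch` names below are illustrative):
+#
+#   def run():
+#       model(batch)
+#
+#   mean_ms, std_ms = benchmark_function(run, iterations=100, warmup=10)
+#   log_benchmark('moe_forward', {'batch_size': 8}, mean_ms, std_ms)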
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/cpu_fused_moe.py b/build/torch210-cxx11-cu130-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
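+
+
+# Minimal sketch of the two activations (shapes are illustrative):
+#
+#   gate = torch.randn(4, 16)
+#   up = torch.randn(4, 16)
+#   y_gptoss = swigluoai_activation(gate, up)   # clamped SwiGLU used by GptOss
+#   y_silu = silu_and_mul_activation(gate, up)  # standard SiLU(gate) * up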
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
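+
+
+# Hedged end-to-end sketch (hypothetical sizes) showing how routing feeds the
+# fused MoE below:
+#
+#   x = torch.randn(32, 1152)                   # [tokens, hidden]
+#   router_w = torch.randn(8, 1152)             # [num_experts, hidden]
+#   _, topk_w, topk_ids = route_tokens_cpu(x, router_w, None, 2, 8)
+#   w1 = torch.randn(8, 1152, 2 * 3072)         # gate_up_proj per expert
+#   w2 = torch.randn(8, 3072, 1152)             # down_proj per expert
+#   out = cpu_fused_moe(x, w1, w2, topk_w, topk_ids, activation="swigluoai")
+#   assert out.shape == (32, 1152)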
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over the experts and, for each one, processes every
+    token routed to it with batched tensor operations, which keeps the per-expert
+    math vectorized on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # For each expert, find the (token_idx, topk_pos) pairs routed to it and
+    # process those tokens in a single batched pass.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
+
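+# End-to-end sketch (assumed toy sizes, never called): wire route_tokens_cpu and
+# cpu_fused_moe together with a standard (non-interleaved) gate/up layout. It only
+# documents how the two helpers compose.
+def _cpu_fused_moe_example() -> torch.Tensor:
+    tokens, hidden, inter, num_experts, top_k = 8, 16, 32, 4, 2
+    x = torch.randn(tokens, hidden)
+    router_weight = torch.randn(num_experts, hidden)
+    w1 = torch.randn(num_experts, hidden, 2 * inter) * 0.02  # gate_up_proj
+    w2 = torch.randn(num_experts, inter, hidden) * 0.02      # down_proj
+    _, weights, indices = route_tokens_cpu(x, router_weight, None, top_k, num_experts)
+    out = cpu_fused_moe(x, w1, w2, weights, indices, activation="silu", is_interleaved=False)
+    assert out.shape == x.shape
+    return out
+
+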
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/cpu_moe_cpp.py b/build/torch210-cxx11-cu130-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py b/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm/ops.py b/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
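+
+
+# Usage sketch (shapes inferred from _allocate_output in backend.py; illustrative
+# only and never called): 'a' stacks every group's rows along dim 0 and 'b' holds
+# one [k, n] matrix per group, so the result has sum(batch_sizes) rows.
+def _gmm_example():
+    batch_sizes = torch.tensor([2, 3], dtype=torch.int64)          # two groups
+    a = torch.randn(5, 8, device="cuda", dtype=torch.bfloat16)     # sum(batch_sizes) x k
+    b = torch.randn(2, 8, 4, device="cuda", dtype=torch.bfloat16)  # num_groups x k x n
+    return gmm(a, b, batch_sizes)                                  # -> [5, 4]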
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm_util.py b/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored inside this package, so there is no external
+    # import to attempt here; mark it as available unconditionally.
+    _grouped_gemm_is_available = True
+except ImportError:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+    )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/layers.py b/build/torch210-cxx11-cu130-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
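+# Worked example (hypothetical configuration, never called): with 8 ranks and 4
+# experts, the experts are sharded 4 ways and the hidden dimension 2 ways, so each
+# rank holds one expert with half of a 1024-wide FFN.
+def _sharding_example():
+    world_size, num_experts, ffn_hidden = 8, 4, 1024
+    assert expert_sharding_degree(world_size, num_experts) == 4
+    assert hidden_sharding_degree(world_size, num_experts, ffn_hidden) == 2
+    assert experts_per_rank(num_experts, world_size) == 1
+    assert features_per_rank(ffn_hidden, world_size, num_experts) == 512
+
+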
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
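+# Shape sketch for mlp_forward (assumed toy sizes, defined but never called): the
+# input is already grouped per expert as [num_experts, capacity, hidden], gate and
+# up are interleaved along the last dim of w1, and the output keeps the input shape.
+def _mlp_forward_shape_example() -> torch.Tensor:
+    ne, cap, hs, isz = 4, 8, 16, 32
+    x = torch.randn(ne, cap, hs)
+    w1 = torch.randn(ne, hs, 2 * isz) * 0.02
+    w1_bias = torch.zeros(ne, 2 * isz)
+    w2 = torch.randn(ne, isz, hs) * 0.02
+    w2_bias = torch.zeros(ne, hs)
+    out = mlp_forward(x, w1, w2, w1_bias, w2_bias)
+    assert out.shape == (ne, cap, hs)
+    return out
+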
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
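+# Small worked example (illustrative values, never called): with top_k=3 the
+# weighted-sum mode blends 1/4 of the shared output with 3/4 of the routed output,
+# so constant inputs of 1.0 and 2.0 combine to 1.75 everywhere.
+def _combine_outputs_example() -> torch.Tensor:
+    shared = torch.full((2, 4), 1.0)
+    routed = torch.full((2, 4), 2.0)
+    return combine_expert_shared_outputs(
+        shared, routed, shared_expert_weighted_sum=True, moe_top_k=3
+    )
+
+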
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup],
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
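+# Worked example (hypothetical sizes, never called): 1024 tokens with top_k=4 over
+# 128 experts average 32 slots per expert, so a capacity factor of 1.0 yields 32.
+def _expert_capacity_example() -> int:
+    return expert_capacity(
+        tokens=1024,
+        top_k=4,
+        num_experts=128,
+        expert_parallel_group=None,
+        moe_capacity_factor=1.0,
+        moe_expert_model_parallelism=False,
+    )  # int(1.0 * 4 * 1024 * 1 / 128) == 32
+
+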
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (count_num_experts,) = tokens_per_expert.size()
+    assert count_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
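+# Toy worked example (illustrative values, never called): a perfectly uniform
+# router over 4 experts with top_k=1 gives scale = 4 / (8 * 1) = 0.5 and a dot
+# product of 2.0, so the auxiliary loss comes out to exactly 1.0.
+def _load_balancing_loss_example() -> torch.Tensor:
+    tokens, num_experts, top_k = 8, 4, 1
+    expert_scores = torch.full((tokens, num_experts), 1.0 / num_experts)
+    tokens_per_expert = torch.full((num_experts,), tokens / num_experts)
+    return load_balancing_loss(tokens_per_expert, expert_scores, top_k, num_experts)
+
+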
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
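+# Plain-PyTorch sketch of what indices_and_bins computes (for documentation only;
+# the CUDA path above is not replaced by this and the helper is never called):
+# sort the flat expert assignments, count tokens per expert, and take the
+# inclusive cumulative sum as bin boundaries.
+def _indices_and_bins_reference(top_expert: torch.Tensor, num_experts: int):
+    bin_ids, indices = torch.sort(top_expert.int())
+    tokens_per_expert = torch.bincount(top_expert.long(), minlength=num_experts)
+    bins = torch.cumsum(tokens_per_expert, dim=0)
+    return indices, bin_ids, bins, tokens_per_expert
+
+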
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/megablocks/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is: once it is added to `sys.modules`,
+ # it would also be picked up by other imports. Instead, we derive a unique
+ # module name from the hex-encoded hash of the file path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
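+# Load the real package __init__.py from the build-variant root (one directory
+# up) under a unique module name and re-export its public names here, so this
+# per-variant `megablocks` package exposes the same API as the top-level one.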
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a9813b81c6c98110d265c184f2016d728202289b
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2 bytes per fp16 element.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
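+# NOTE: this entry point expects the usual torch.distributed environment
+# variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), e.g. as set by
+# torchrun, since init_process_group() reads its configuration from them.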
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/binned_gather.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
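+# Usage note (a sketch, based on how the op is called elsewhere in this package):
+# binned_gather(x, indices, bins, bin_size, top_k) groups tokens by expert into
+# a dense (num_experts, bin_size, hidden) buffer, where bin_size is the fixed
+# per-expert capacity; binned_scatter is its inverse.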
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/binned_scatter.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
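+# Usage note (sketch): binned_scatter(x, indices, weights, bins, top_k) writes
+# each token's expert outputs back to their original positions, scaling by the
+# router `weights` when provided and reducing the top_k copies per token.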
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/cumsum.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap the import in a try-block so that a missing or unbuilt C++ extension
+# raises a clearer error message.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
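+# Usage sketch (mirrors the benchmark code in this directory):
+#   tokens_per_expert = histogram(top_expert, num_experts)   # int32 counts
+#   bins = inclusive_cumsum(tokens_per_expert, 0)            # end offset per expert
+#   starts = exclusive_cumsum(tokens_per_expert, 0)          # start offset per expert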
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/gather.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/histogram.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap the import in a try-block so that a missing or unbuilt C++ extension
+# raises a clearer error message.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
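+# Usage sketch: histogram(x, max_val) counts how many entries of `x` fall into
+# each of the buckets 0..max_val-1, e.g.
+#   tokens_per_expert = histogram(top_expert, num_experts)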
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/histogram_benchmark.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/matmul_benchmark.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/padded_gather.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
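+# Usage note (sketch): like gather, but each expert's token group is padded so
+# that group boundaries (`padded_bins`, typically produced via round_up to the
+# sparse block size) align with the block-sparse matmul blocking.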
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/padded_scatter.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/permute_benchmark.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/repeat.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/replicate.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap the import in a try-block so that a missing or unbuilt C++ extension
+# raises a clearer error message.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/round_up.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+ # TODO(tgale): If this becomes an issue
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
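+# Example: round_up(torch.tensor([5, 128, 130], dtype=torch.int32), 128)
+# returns tensor([128, 128, 256], dtype=torch.int32).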
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/scatter.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
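+# Usage note (sketch): scatter is the inverse of gather in the dMoE pipeline:
+# it un-permutes the per-expert outputs back to token order, scales each copy
+# by the router `weights`, and reduces the top_k copies for every token.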
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/sort.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap the import in a try-block so that a missing or unbuilt C++ extension
+# raises a clearer error message.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
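+# Usage sketch: sort(x, end_bit) radix-sorts the low `end_bit` bits of `x` and
+# also returns the permutation, e.g.
+#   bin_ids, indices = sort(top_expert, sort_end_bit)
+# where `indices` can then be used to gather tokens grouped by expert.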
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/sort_benchmark.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/stk_autocast.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+ # NOTE: the vendored source checks `isinstance(x, map)`, which never matches a dict.
+ elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/sum.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/ops/topology.py b/build/torch210-cxx11-cu130-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap the import in a try-block so that a missing or unbuilt C++ extension
+# raises a clearer error message.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
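+# Usage note (sketch): topology(padded_bins, block_size, block_rows, blocks_per_row)
+# returns the int16 column indices of the nonzero blocks in the block-sparse
+# token-by-expert-FFN topology, as consumed by the stk sparse matmul kernels.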
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/autocast.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+ # NOTE: the vendored source checks `isinstance(x, map)`, which never matches a dict.
+ elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/sputnik.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/triton_kernels.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
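+
+# Worked example (hedged sketch): for a BCSR matrix with three block rows and
+# offsets = [0, 2, 2, 5], block row 0 owns nonzero blocks 0-1, row 1 owns none,
+# and row 2 owns blocks 2-4, so the kernel fills out = [0, 0, 2, 2, 2].
+# shape, data and column_indices are unused by this Triton backend.
+#
+#   offsets = torch.tensor([0, 2, 2, 5], dtype=torch.int32, device="cuda")
+#   out = torch.empty(5, dtype=torch.int32, device="cuda")
+#   row_indices(None, None, offsets, None, out)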
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/matrix.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
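+
+# Worked example (hedged sketch): for a 2x2-block anti-diagonal matrix with
+# blocking 1, size (2, 2), row_indices = [0, 1] and column_indices = [1, 0],
+# sorting by column gives column_indices_t = [1, 0], block_offsets_t = [1, 0]
+# and offsets_t = [0, 1, 2] (one nonzero block per column of the transpose).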
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
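+
+# Usage sketch (hedged; assumes a CUDA device and that stk is importable):
+#
+#   import torch, stk
+#   a = stk.random.randn((256, 256), sparsity=0.5, blocking=128).cuda()
+#   b = stk.Matrix(a.size(), torch.ones_like(a.data), a.row_indices,
+#                  a.column_indices, a.offsets)
+#   c = stk.ops.mul(a, b)   # shares a's topology; entries are a.data * b.data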
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/linear_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
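+
+# Usage sketch for the three products above (hedged; assumes CUDA, fp16 and
+# that stk is importable):
+#
+#   import torch, stk
+#   topo = stk.random.mask(512, 512, sparsity=0.75, blocking=128).cuda()
+#   x = torch.randn(512, 512, dtype=torch.float16, device="cuda")
+#   w = torch.randn(512, 512, dtype=torch.float16, device="cuda")
+#   y = stk.ops.sdd(x, w, topo)   # dense @ dense -> sparse with topo's layout
+#   z = stk.ops.dsd(y, w)         # sparse @ dense -> dense
+#   u = stk.ops.dds(x, y)         # dense @ sparse -> dense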
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
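+
+# Round-trip sketch (hedged): converting a block-sparse fp16 matrix to sparse
+# format and back should reproduce it exactly, e.g.
+#
+#   x = torch.zeros(256, 256, dtype=torch.float16)
+#   x[:128, :128] = torch.randn(128, 128).half()   # a single nonzero block
+#   assert torch.equal(to_dense(to_sparse(x, blocking=128)), x)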
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/random/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/random/random_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
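+
+# Usage sketch (hedged): build a ~75%-sparse 512x512 fp16 matrix with 128x128
+# blocking; nnz counts the stored block values, not the logical matrix size.
+#
+#   x = randn((512, 512), sparsity=0.75, blocking=128)
+#   assert x.nnz == 4 * 128 * 128   # 4 of the 16 blocks are kept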
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/stk/random/random_ops_test.py b/build/torch210-cxx11-cu130-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from stk import random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/xpu_fused_moe.py b/build/torch210-cxx11-cu130-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped gemm wrapper (expert offsets computed from per-expert token counts).
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
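+
+# Worked example (hedged sketch): with 1024 tokens and 8 experts per node,
+# a block of 32 gives ceilDiv(1024, 32) * 8 = 256 > 32 and 64 gives 128 > 64,
+# while 128 is the first size with ceilDiv(1024, 128) * 8 = 64 <= 128, so the
+# function returns 128.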
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
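+
+# Example (hedged sketch): 16 bytes reinterpreted as int32 yield 4 elements;
+# the values match what `Tensor.view(torch.int32)` would produce, but via a copy.
+#
+#   buf = torch.arange(16, dtype=torch.uint8)
+#   vals = _bytes_to_typed_tensor(buf, torch.int32)   # shape (4,)
+#   assert torch.equal(vals, buf.view(torch.int32))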
+
+
+def implement_zp(qweight):
+    # Change u4 to s4 to avoid handling a zero point in the gemm kernel.
+    # Only the default zero point is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
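+
+# Worked example (hedged sketch; the sign/low-3-bit packing below is an
+# assumption about the int4 format expected by the grouped gemm kernel):
+# the byte 0x9F holds the u4 nibbles (9, 15); subtracting the default zero
+# point of 8 gives the s4 values (1, 7), which re-pack to 0x17.
+#
+#   packed = implement_zp(torch.tensor([[0x9F]], dtype=torch.uint8))
+#   assert packed.item() == 0x17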
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+    # 4-bit weights use an [E, N, K] layout;
+    # other dtypes use [E, K, N].
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ func. Temporarily exposed before gemm fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
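+
+# Example (hedged sketch): for logits [[0.1, 2.0, -1.0, 0.5]] and moe_top_k=2,
+# torch.topk returns values [[2.0, 0.5]] and indices [[1, 3]]; with
+# moe_top_k=1 the max/keepdim path returns the same leading expert with k=1 shapes.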
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
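+
+
+# Illustrative routing sketch (assumed shapes, comment only): for 6 tokens with
+# hidden_size = 1152, 128 experts and top_k = 4,
+#
+#   x: [6, 1152], router_weight: [128, 1152], router_bias: [128]
+#   logits, weights, ids = route_tokens_xpu(
+#       x, router_weight, router_bias, moe_top_k=4, moe_num_experts=128)
+#   # logits: [6, 128]; weights, ids: [6, 4]. The weights are a softmax over
+#   # the selected logits and sum to 1 per token unless
+#   # moe_normalize_expert_weights re-normalizes them by the given p-norm.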
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/__init__.py b/build/torch210-cxx11-xpu20253-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
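+
+    Example (illustrative sketch; counts integer ids into num_bins buckets):
+
+        ids = torch.tensor([0, 2, 2, 1], device="xpu")
+        histogram(ids, 4)   # -> tensor([1, 1, 2, 0])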
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
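+
+    Example (illustrative sketch; assumes the compiled ops extension and a
+    supported device such as XPU are available):
+
+        x = torch.tensor([1, 2, 3], device="xpu")
+        cumsum(x)                   # inclusive -> tensor([1, 3, 6])
+        cumsum(x, exclusive=True)   # exclusive -> tensor([0, 1, 3])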
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
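+
+    Example (illustrative sketch; the index tensor's dtype follows the
+    underlying radix-sort kernel):
+
+        vals = torch.tensor([3, 1, 2], device="xpu")
+        sorted_vals, order = argsort(vals)
+        # sorted_vals -> tensor([1, 2, 3]); order holds each element's
+        # original position, e.g. tensor([1, 2, 0])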
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/__init__.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/activation_fn.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/all_to_all.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
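+
+
+# Illustrative usage sketch (comment only): assumes torch.distributed is
+# initialized and the split sizes are consistent across ranks.
+#
+#   out, handle = all_to_all(x, output_split_sizes, input_split_sizes, group,
+#                            async_op=True)
+#   handle.wait()  # wait for the exchange to finish before reading `out`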
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/arguments.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from packaging import version
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+                if version.parse(triton.__version__) >= version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/common.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/dmlp_registry.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
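+
+    Example (illustrative; assumes a suitable device and, for the grouped
+    backend, the grouped GEMM extension):
+
+        args = Arguments(mlp_type='glu', mlp_impl='grouped')
+        expert_mlp = get(args)   # -> GroupedGLU instance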
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/dmoe.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
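+        #
+        # For example (illustrative), with blocks_per_row = 4 a block at
+        # row_index = 2, column_index = 9 has offset 2 * 4 + 9 % 4 = 9;
+        # dividing by blocks_per_row recovers 9 // 4 = 2 as the transposed
+        # column index.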
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
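+        # Illustrative example: with blocking = 128, ffn_hidden_size = 512 and
+        # 256 padded tokens, blocks_per_row = 4, block_rows = 2 and
+        # offsets = [0, 4, 8] -- each block row owns exactly four nonzero blocks.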
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
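+        # Worked example (illustrative): with tokens_per_expert = [3, 5] and
+        # blocking = 128, padded_tokens_per_expert = [128, 128] and
+        # padded_bins = [128, 256], so every expert's tokens start on a
+        # block-aligned boundary of the permuted activation matrix.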
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/gelu.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/glu.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/memory_test.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/mlp.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
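+
+    # Illustrative example: with num_experts = 8, expert_sharding_degree = 4
+    # and hidden_sharding_degree = 2, rank 5 maps to expert_rank = 1 (experts
+    # 2..3) and row_rank = 1 (the second half of each expert's ffn rows).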
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
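+        #
+        # Illustrative example: with moe_top_k = 2, t_experts = 3, so the
+        # combined output is shared_expert_out / 3 + expert_out * (2 / 3);
+        # the shared expert counts as one of three contributing experts.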
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
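
The `add_experts_sharedexpert` helper above folds the shared expert into the routed-expert mixture: when `shared_expert_weighted_sum` is set, the shared expert is treated as one extra expert next to the `moe_top_k` routed ones, so the two outputs are combined with weights `1/(top_k+1)` and `top_k/(top_k+1)`. A minimal sketch of that arithmetic in plain PyTorch (the `top_k` value and tensor shapes are illustrative assumptions, not taken from the diff):

```python
import torch

# Hypothetical shapes and top_k; only the combination rule matters here.
top_k = 2
shared_out = torch.randn(4, 8)   # [tokens, hidden] from the shared expert
expert_out = torch.randn(4, 8)   # [tokens, hidden] from the routed experts

t_experts = top_k + 1            # routed experts plus the shared expert
weighted = shared_out / t_experts + expert_out * (top_k / t_experts)

# Without shared_expert_weighted_sum the two outputs are simply added.
plain = shared_out + expert_out
print(weighted.shape, plain.shape)
```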
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/moe.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+            # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
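
`indices_and_bins` above is the heart of the token permutation: it sorts the flattened expert assignments, histograms them to get per-expert token counts, and takes an inclusive cumsum to obtain bin boundaries into the sorted order. The same bookkeeping can be reproduced with stock PyTorch ops; a small sketch with made-up assignments, where `torch.sort`/`bincount`/`cumsum` stand in for the fused `sort`/`histogram`/`inclusive_cumsum` kernels:

```python
import torch

num_experts = 4
# Hypothetical flattened expert assignments for 3 tokens with top_k = 2.
top_experts = torch.tensor([2, 0, 1, 2, 0, 3])

# Equivalent of sort: bin_ids are the sorted expert ids, indices the permutation.
bin_ids, indices = torch.sort(top_experts)

# Equivalent of histogram: number of tokens routed to each expert.
tokens_per_expert = torch.bincount(top_experts, minlength=num_experts)

# Equivalent of inclusive_cumsum: bin boundaries into the sorted token array.
bins = torch.cumsum(tokens_per_expert, dim=0)

print(bin_ids)            # tensor([0, 0, 1, 2, 2, 3])
print(tokens_per_expert)  # tensor([2, 1, 2, 1])
print(bins)               # tensor([2, 3, 5, 6])
```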
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/mpu.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+        raise ValueError(
+            f"Invalid sharding. 'expert_sharding_degree' ({esd}) * 'hidden_sharding_degree' ({hsd}) != world_size ({world_size}).",
+        )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
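
The sharding helpers above split the expert-parallel world size into two factors: `expert_sharding_degree = min(world_size, num_experts)` and `hidden_sharding_degree = world_size // expert_sharding_degree`, with divisibility checks on both. A worked example with assumed numbers (no process group needed, the arithmetic is the whole point):

```python
# Assumed configuration, purely for illustration.
world_size = 8          # expert-parallel ranks
moe_num_experts = 4
ffn_hidden_size = 4096

esd = min(world_size, moe_num_experts)      # expert_sharding_degree -> 4
assert moe_num_experts % esd == 0
hsd = world_size // esd                     # hidden_sharding_degree -> 2
assert ffn_hidden_size % hsd == 0
assert esd * hsd == world_size

experts_per_rank = moe_num_experts // esd   # 1 expert per rank
features_per_rank = ffn_hidden_size // hsd  # 2048 features per rank
print(esd, hsd, experts_per_rank, features_per_rank)
```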
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/router.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
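
The router above is just a linear layer followed by a softmax, a top-k selection, and an optional re-normalization of the selected weights. A self-contained sketch of that routing math (hidden size, expert count, and top-k are made up; a plain sum stands in for the `p=1` normalization case):

```python
import torch

torch.manual_seed(0)
hidden_size, num_experts, top_k = 16, 8, 2
x = torch.randn(5, hidden_size)                      # 5 tokens
router = torch.nn.Linear(hidden_size, num_experts, bias=False)

logits = router(x)
scores = logits.softmax(dim=-1)                      # [tokens, num_experts]
expert_weights, expert_indices = torch.topk(scores, top_k, dim=-1)

# Optional: renormalize so the selected weights sum to 1 per token
# (equivalent to the p=1 norm when moe_normalize_expert_weights=1).
expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)
print(expert_weights.sum(dim=-1))                    # ~1.0 for each token
```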
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+    """Returns a SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_megablocks_xpu_7a6bcf4.abi3.so b/build/torch210-cxx11-xpu20253-x86_64-linux/_megablocks_xpu_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..9ea39ea7a409dce7a5431c24e20fe754bbd42787
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_megablocks_xpu_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b22948998af13e0b921419366b6f68a1dd0e649e7ccb8c55c123c1aa9f3ec5b
+size 5381760
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c8dd6eeccd632df5e23111e5dd5221d3e1fcb47
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_xpu_7a6bcf4
+ops = torch.ops._megablocks_xpu_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_xpu_7a6bcf4::{op_name}"
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/_version.py b/build/torch210-cxx11-xpu20253-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/backend/__init__.py b/build/torch210-cxx11-xpu20253-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/backend/kernels.py b/build/torch210-cxx11-xpu20253-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
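
`padded_gather` sizes its output from `padded_bins[-1]`, which is why the comment above notes that the output size is dynamic. In the dMoE path the padded bins typically round each expert's token count up to a block-size multiple so the downstream block-sparse matmul sees full blocks; that rounding happens in the caller, not in this file, so the sketch below is only an assumed illustration of how `bins` and `padded_bins` relate:

```python
import torch

block_size = 128                                   # assumed sparse block size
tokens_per_expert = torch.tensor([5, 0, 200, 17])  # assumed per-expert counts

# Inclusive cumsum of the raw counts -> bin boundaries used by gather/scatter.
bins = torch.cumsum(tokens_per_expert, dim=0)

# Round each count up to a multiple of the block size, then cumsum -> padded bins.
padded = ((tokens_per_expert + block_size - 1) // block_size) * block_size
padded_bins = torch.cumsum(padded, dim=0)

print(bins)         # tensor([  5,   5, 205, 222])
print(padded_bins)  # tensor([128, 128, 384, 512])
# padded_bins[-1] (512 here) is the number of rows padded_gather allocates.
```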
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/benchmark_util.py b/build/torch210-cxx11-xpu20253-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
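
A short usage sketch for the two helpers above; the benchmarked function is an arbitrary matmul and a CUDA device is assumed, since timing is done with `torch.cuda.Event`:

```python
import torch

# Assumes benchmark_function and log_benchmark from above are in scope
# and that a CUDA device is available.
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

mean_ms, std_ms = benchmark_function(lambda: a @ b, iterations=50, warmup=5)
log_benchmark("matmul", {"m": 1024, "n": 1024, "k": 1024}, mean_ms, std_ms)
```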
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/cpu_fused_moe.py b/build/torch210-cxx11-xpu20253-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation processes all experts in parallel using batched operations
+ instead of sequential for loops, which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # Process experts one at a time: for each expert, find the
+    # (token_idx, topk_pos) pairs that were routed to it.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
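
An end-to-end CPU sketch that exercises `route_tokens_cpu` and `cpu_fused_moe` from the file above. All sizes are small made-up values, the weights are random, and the non-interleaved `[gate_all, up_all]` layout is used, so only the output shape is meaningful:

```python
import torch

torch.manual_seed(0)
tokens, hidden, inter, num_experts, top_k = 6, 32, 64, 4, 2

x = torch.randn(tokens, hidden)
router_weight = torch.randn(num_experts, hidden) * 0.02
w1 = torch.randn(num_experts, hidden, 2 * inter) * 0.02  # gate_up_proj
w2 = torch.randn(num_experts, inter, hidden) * 0.02      # down_proj

# Route tokens, then run the fused CPU MoE with a standard SwiGLU layout.
logits, weights, ids = route_tokens_cpu(x, router_weight, None, top_k, num_experts)
out = cpu_fused_moe(
    x, w1, w2, weights, ids,
    activation="silu", is_interleaved=False,
)
print(out.shape)  # torch.Size([6, 32])
```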
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/cpu_moe_cpp.py b/build/torch210-cxx11-xpu20253-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
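+# Minimal usage sketch for fused_moe_cpp (illustrative only; the shapes and
+# dtypes below are assumptions for a toy configuration, not values taken from
+# a specific model):
+#
+#   M, K, N, E, topk = 4, 64, 128, 8, 2
+#   hidden = torch.randn(M, K, dtype=torch.bfloat16)
+#   w1 = torch.randn(E, 2 * N, K, dtype=torch.bfloat16)   # fused gate/up projections
+#   w2 = torch.randn(E, K, N, dtype=torch.bfloat16)       # down projection
+#   weights, ids = torch.topk(torch.softmax(torch.rand(M, E), dim=-1), topk, dim=-1)
+#   out = fused_moe_cpp(hidden, w1, w2, weights, ids.to(torch.int32))
+#   # out has shape [M, K], matching hidden_states.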
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+        # mxfp4 is only enabled below when GPT-OSS-style precision configs are present.
+        if getattr(self, "use_mxfp4", None) is None:
+            self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+                # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3, or uint8 (mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm/__init__.py b/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm/backend.py b/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
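+
+# Shape conventions for gmm, as implied by _allocate_output (sketch; the
+# variable names are illustrative):
+#   trans_a=False, trans_b=False: a is [tokens, K], b is [num_groups, K, N],
+#       batch_sizes holds the per-group row counts of 'a' (summing to tokens),
+#       and the result is [tokens, N].
+#   trans_a=False, trans_b=True:  b is [num_groups, N, K] and the result is
+#       again [tokens, N].
+#   trans_a=True:                 a is [tokens, K], b is [tokens, N], and the
+#       result is [num_groups, K, N] (the weight-gradient case).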
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm/ops.py b/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
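+
+# Autograd-aware usage sketch (illustrative shapes and device; not taken from
+# a real training setup):
+#
+#   a = torch.randn(6, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+#   b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+#   batch_sizes = torch.tensor([4, 2])   # rows of 'a' handled by each group
+#   out = gmm(a, b, batch_sizes)         # out: [6, 32]
+#   out.sum().backward()                 # gradients flow to 'a' and 'b' via GroupedGemm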
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm_util.py b/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored inside this package, so there is no external
+    # dependency to import here.
+    _grouped_gemm_is_available = True
+except ImportError:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+    )
+    assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/layers.py b/build/torch210-cxx11-xpu20253-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
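+
+# Worked example of the sharding arithmetic above (assumed toy sizes): with
+# world_size = 8, moe_num_experts = 4 and ffn_hidden_size = 3072, the expert
+# sharding degree is min(8, 4) = 4, the hidden sharding degree is 8 // 4 = 2,
+# each rank owns 4 // 4 = 1 expert, and each expert's FFN is split into
+# 3072 // 2 = 1536 features per rank.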
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
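+
+# Routing sketch (shapes only): for x of shape [sl, bs, hs] the router yields
+# logits of shape [sl * bs, num_experts]; compute_top_k keeps the moe_top_k
+# largest logits per token and the softmax is taken over just those k values,
+# so expert_weights and expert_indices both have shape [sl * bs, moe_top_k].
+# When moe_normalize_expert_weights is set, the weights are further divided by
+# their p-norm along the last dimension.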
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
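+
+# Activation sketch for the block above: gate and up are interleaved along the
+# last dimension of the fused projection (even columns = gate, odd columns = up).
+# Per element, with the defaults alpha = 1.702 and limit = 7.0:
+#   g   = min(gate, limit)
+#   glu = g * sigmoid(alpha * g)
+#   out = (clamp(up, -limit, limit) + 1) * glu
+# i.e. the clamped "swigluoai" SwiGLU variant rather than a plain SiLU MLP.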
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+    if len(expert_scores) != num_layers_per_pipeline_stage:
+        raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+            f"{args.num_layers}\npipeline_model_parallel_size = "
+            f"{args.pipeline_model_parallel_size}\n"
+            "num_layers_per_virtual_pipeline_stage"
+            f" = {args.num_layers_per_virtual_pipeline_stage}",
+        )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
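+
+# Worked example for load_balancing_loss (assumed toy numbers): with
+# num_experts = 4, top_k = 2 and tokens = 8 there are 16 assignments, so a
+# perfectly balanced router gives tokens_per_expert = [4, 4, 4, 4] and mean
+# expert_scores of 0.25 per expert. The loss is then
+#   4 / (8 * 2) * dot([4, 4, 4, 4], [0.25, 0.25, 0.25, 0.25]) = 0.25 * 4 = 1.0,
+# and it grows as routing collapses onto fewer experts.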
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
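+
+# Toy example of the tensors produced above (num_experts = 4):
+#   top_expert        = [2, 0, 2, 1]            # expert id per assignment
+#   bin_ids, indices  = sort(top_expert)        # -> [0, 1, 2, 2], [1, 3, 0, 2]
+#   tokens_per_expert = histogram(top_expert)   # -> [1, 1, 2, 0]
+#   bins              = inclusive_cumsum(...)   # -> [1, 2, 4, 4]
+# so bins[e] is the end offset of expert e's contiguous slice in sorted order.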
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
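+
+# Capacity arithmetic sketch (assumed toy numbers): with tokens = 1024,
+# top_k = 4, num_experts = 128, a single device (world_size = 1) and
+# moe_capacity_factor = 1.0, each expert is sized for
+# int(1.0 * 4 * 1024 * 1 / 128) = 32 token slots; a capacity factor of 1.25
+# raises that to 40 slots per expert.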
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/megablocks/__init__.py b/build/torch210-cxx11-xpu20253-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/metadata.json b/build/torch210-cxx11-xpu20253-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..b911d0a2549a35a1c65ab7e77d32e5aac23cd6ac
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/metadata.json
@@ -0,0 +1,8 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "xpu"
+ }
+}
\ No newline at end of file
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/__init__.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per fp16 element.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/binned_gather.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/binned_scatter.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/cumsum.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/gather.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
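+ # Permute the rows of 'x' into contiguous per-expert ranges described by 'bins'.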
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/histogram.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrapped in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
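+ # Bucket counts for the integer values in 'x' over the range [0, max_val).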
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/histogram_benchmark.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
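+ # Wait for the GPU so the recorded events give a valid elapsed time.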
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/matmul_benchmark.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid the overhead
+# it adds.
+def transpose_view(x):
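+ # Equivalent to x.t() for a 2D tensor, without the transpose/as_strided dispatch.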
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/padded_gather.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
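+ # Like gather, but each expert's range is padded up to a block boundary ('padded_bins').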
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/padded_scatter.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
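+ # Only save 'x' when the gradient w.r.t. 'weights' will be needed.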
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/permute_benchmark.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/repeat.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
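+ # A tiling of all ones is a no-op; skip the copy that Tensor.repeat would make.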
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/replicate.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrapped in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
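+ # Replicate each value of 'x' across the span of its bin to fill 'num_outputs' columns.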
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/round_up.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+ # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
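+ # Example (illustrative): round_up(torch.tensor([3, 130], dtype=torch.int32), 128)
+ # -> tensor([128, 256], dtype=torch.int32).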
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/scatter.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
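+ # Save 'x' for backward only if a gradient w.r.t. 'weights' is required.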
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/sort.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrapped in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
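+ # Radix sort 'x' and also return the permutation ('iota_out') that produced the order.
+ # 'end_bit' bounds how many low bits are sorted, defaulting to the full dtype width.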
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/sort_benchmark.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/stk_autocast.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+ elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
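+ # Run the backward pass with autocast disabled so it sees the dtypes chosen in forward.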
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/sum.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
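+ # Summing over a singleton dimension is just a squeeze; avoid the reduction.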
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/ops/topology.py b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrapped in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
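+ # ops.indices fills 'out' with the column index of each nonzero block in the padded block-sparse topology.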
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/__init__.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/__init__.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/autocast.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+ elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/sputnik.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
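+ # A tensor is treated as transposed when it is stored column-major (unit stride along dim 0).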
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
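+# DSD: sparse lhs x dense rhs -> dense output.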
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
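+# DDS: dense lhs x sparse rhs -> dense output.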
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
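+# SDD: dense lhs x dense rhs -> sparse output with the topology of 'data'.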
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/triton_kernels.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ # Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
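+
+# Illustrative sketch (assumed toy values, not part of the kernel contract): for
+# offsets = [0, 2, 3, 3, 5] the launcher starts len(offsets) - 1 = 4 programs, one
+# per block row, and each program writes its row id once per nonzero block, giving
+# row_indices = [0, 0, 1, 3, 3]. `out` must therefore be preallocated with
+# offsets[-1] entries (one per nonzero block).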
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/matrix.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D data.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
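+# Illustrative sketch (assumed toy inputs): with row_indices = [0, 0, 1] and
+# column_indices = [0, 1, 1] on a 2x2 block grid, argsort(column_indices) is
+# [0, 1, 2], so column_indices_t = [0, 0, 1] (the original row ids grouped by
+# column), the per-column histogram is [1, 2] giving offsets_t = [0, 1, 3], and
+# block_offsets_t = [0, 1, 2] records where each transposed block lives in the
+# untransposed data array.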
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/__init__.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
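+
+# Minimal usage sketch (illustrative; assumes the package is importable as `stk`,
+# as in the tests in this tree):
+#
+#   a = stk.random.randn((128, 128), 0.5, blocking=16)
+#   b = stk.ops.ones_like(a)     # shares a's topology by construction
+#   c = stk.ops.mul(a, b)        # entrywise product, same topology as a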
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
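+
+# Usage sketch (illustrative; `topo` is any stk.Matrix carrying the desired
+# output sparsity pattern):
+#
+#   y = dsd(a_sparse, b_dense)        # dense = sparse @ dense
+#   y = dds(a_dense, b_sparse)        # dense = dense @ sparse
+#   s = sdd(a_dense, b_dense, topo)   # sparse = (dense @ dense) on topo's pattern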
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler than its current implementation.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
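+
+# Round-trip sketch (illustrative, mirroring the format-conversion test in this
+# tree; `stk.random.dense_mask` is used only to build a block-sparse pattern):
+#
+#   x = (torch.randn(128, 128) * stk.random.dense_mask(128, 128, 0.5, 16)).half()
+#   sp = to_sparse(x, blocking=16)        # BCSR Matrix with 16x16 blocks
+#   assert torch.equal(to_dense(sp), x)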
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/random/__init__.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/random/random_ops.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
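+
+# Worked example (assumed numbers): dense_mask(8, 8, sparsity=0.75, blocking=4)
+# partitions the matrix into four 4x4 blocks and keeps round(4 * 0.25) = 1 of
+# them, so exactly 16 of the 64 entries are 1.0. mask() converts that pattern to
+# a BCSR Matrix, and randn() additionally fills the retained blocks with normal
+# noise.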
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/stk/random/random_ops_test.py b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch210-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py b/build/torch210-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch210-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# default
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
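+# Worked example (assumed numbers): with num_tokens=1000 and 8 experts per node,
+# 32 and 64 are rejected (ceilDiv(1000, 32) * 8 = 256 > 32; ceilDiv(1000, 64) * 8
+# = 128 > 64), while 128 satisfies ceilDiv(1000, 128) * 8 = 64 <= 128, so 128 is
+# returned.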
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
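+# Example (assumed buffer): reinterpreting an 8-byte uint8 workspace slice with
+# dtype=torch.int64 yields a 1-element tensor. The byte copy produces a
+# standalone tensor rather than a view that aliases the workspace.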
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
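+# Illustrative example (single assumed byte): qweight = 0x9F packs the unsigned
+# nibbles 9 (high) and 15 (low); subtracting 8 gives the signed values 1 and 7,
+# which pack_compact re-encodes as a sign bit followed by the low three bits of
+# each value, producing the output byte 0x17.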
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ function. Temporarily exposed here before gemm fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
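+# Call sketch (hypothetical shapes; bf16 weights, no quantization, single rank):
+#
+#   x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="xpu")
+#   out = xpu_fused_moe(x, w13, None, None, w2, None, None,
+#                       topk_weights.float(), topk_ids,
+#                       n_experts_per_token=top_k, activation="silu",
+#                       num_experts=num_experts)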
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
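+# Shape sketch (assumed sizes): for x of shape [T, H], router_weight of shape
+# [E, H] and moe_top_k = 4, logits is [T, E], expert_weights is the softmax over
+# the four selected logits per token ([T, 4]), and expert_indices is the matching
+# [T, 4] integer tensor consumed by xpu_fused_moe above.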
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/__init__.py b/build/torch211-cxx11-cpu-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
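+# Worked example (assumed values): for x = [1, 2, 3] the inclusive cumsum is
+# [1, 3, 6], while the exclusive variant above yields [0, 1, 3]; both copy the
+# result into the caller-provided `out` tensor and return it.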
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
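The cumsum wrappers above delegate to the compiled ops; for reference, the expected semantics can be reproduced with plain PyTorch. This is an illustrative sketch only (the `torch_*` helpers are made-up names, not part of this package):

```python
import torch


def torch_inclusive_cumsum(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    return torch.cumsum(x, dim=dim)


def torch_exclusive_cumsum(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # The exclusive sum at position i is the inclusive sum minus x[i].
    return torch.cumsum(x, dim=dim) - x


if __name__ == "__main__":
    t = torch.tensor([2, 3, 5])
    print(torch_inclusive_cumsum(t))  # tensor([ 2,  5, 10])
    print(torch_exclusive_cumsum(t))  # tensor([0, 2, 5])
```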
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/__init__.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/activation_fn.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
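`act_fn` applies the activation to the sparse matrix's `.data` and can hand back `out.backward` for later gradient rematerialization. A dense-tensor sketch of that same pattern, with illustrative names only:

```python
import torch
import torch.nn.functional as F


def act_with_grad_fn(x: torch.Tensor, fn=F.gelu, return_grad_fn: bool = False):
    # Run the activation with grad enabled so the caller can replay backward later.
    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
        if return_grad_fn:
            x.requires_grad_(True)
        out = fn(x)
    if return_grad_fn:
        return out, out.backward
    return out


if __name__ == "__main__":
    x = torch.randn(4, 4)
    y, grad_fn = act_with_grad_fn(x, return_grad_fn=True)
    grad_fn(torch.ones_like(y))  # populates x.grad without keeping extra state
    print(x.grad.shape)
```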
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/all_to_all.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
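`all_to_all` relies on matching `output_split_sizes`/`input_split_sizes` on every rank: what rank r receives from rank d is exactly what d sends to r. A pure-Python sketch of that bookkeeping (no process group required; `output_split_sizes_for` is a made-up helper, the real exchange is done by `dist.all_to_all_single`):

```python
def output_split_sizes_for(rank: int, input_split_sizes_per_rank: list[list[int]]) -> list[int]:
    # Rank `rank` receives sizes[d][rank] items from every source rank d,
    # i.e. the transposed column of the size matrix.
    return [row[rank] for row in input_split_sizes_per_rank]


if __name__ == "__main__":
    # input_split_sizes_per_rank[r][d] = tokens rank r sends to rank d
    sizes = [
        [2, 1, 0],  # rank 0 sends 2/1/0 tokens to ranks 0/1/2
        [0, 3, 1],  # rank 1
        [4, 0, 2],  # rank 2
    ]
    for r in range(3):
        print(r, output_split_sizes_for(r, sizes))
    # rank 0 receives [2, 0, 4], rank 1 receives [1, 3, 0], rank 2 receives [0, 1, 2]
```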
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/arguments.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
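`from_megatron` copies any attribute whose name matches an `Arguments` field. A self-contained sketch of the same field-copy pattern, using a made-up `TrainerConfig` stand-in so it runs without CUDA or grouped-GEMM dependencies:

```python
import dataclasses
from types import SimpleNamespace


@dataclasses.dataclass
class TrainerConfig:
    hidden_size: int = 1024
    ffn_hidden_size: int = 4096
    moe_num_experts: int = 1
    moe_top_k: int = 1


def copy_matching_fields(cfg_cls, external):
    # Start from the defaults, then overwrite any field the external args define.
    cfg = cfg_cls()
    for field in dataclasses.fields(cfg):
        if hasattr(external, field.name):
            setattr(cfg, field.name, getattr(external, field.name))
    return cfg


if __name__ == "__main__":
    megatron_like = SimpleNamespace(hidden_size=2048, moe_top_k=2, unrelated_flag=True)
    print(copy_matching_fields(TrainerConfig, megatron_like))
    # TrainerConfig(hidden_size=2048, ffn_hidden_size=4096, moe_num_experts=1, moe_top_k=2)
```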
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/common.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/dmlp_registry.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
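A toy version of the two-level `{mlp_type: {mlp_impl: class}}` lookup and its error handling, with plain callables standing in for the MLP/GLU classes (names are illustrative only):

```python
from typing import Callable, Dict


def sparse_stub(width: int) -> str:
    return f"sparse_mlp(width={width})"


def grouped_stub(width: int) -> str:
    return f"grouped_mlp(width={width})"


_TOY_REGISTRY: Dict[str, Dict[str, Callable[[int], str]]] = {
    "mlp": {"sparse": sparse_stub, "grouped": grouped_stub},
}


def build_layer(mlp_type: str, mlp_impl: str, width: int) -> str:
    if mlp_type not in _TOY_REGISTRY:
        raise ValueError(f"Unsupported mlp type: {mlp_type}")
    if mlp_impl not in _TOY_REGISTRY[mlp_type]:
        raise ValueError(f"{mlp_type} does not support {mlp_impl} backend.")
    return _TOY_REGISTRY[mlp_type][mlp_impl](width)


if __name__ == "__main__":
    print(build_layer("mlp", "grouped", 4096))  # grouped_mlp(width=4096)
```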
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/dmoe.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+ # A blocks offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
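The padded-bin bookkeeping in `indices_and_padded_bins` can be reproduced with stock PyTorch ops, which is handy for sanity-checking kernel outputs. A CPU-only reference sketch (the helper name is made up; this is not the code path used above):

```python
import torch


def padded_bins_reference(top_experts: torch.Tensor, num_experts: int, blocking: int):
    # Sort tokens by expert id (stand-in for ops.sort) and count assignments.
    bin_ids, indices = torch.sort(top_experts)
    tokens_per_expert = torch.bincount(top_experts, minlength=num_experts)
    # Round each expert's count up to the block size (stand-in for ops.round_up),
    # then take inclusive cumsums to get bin boundaries.
    padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking
    padded_bins = torch.cumsum(padded, dim=0)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return indices, bin_ids, bins, padded_bins, tokens_per_expert


if __name__ == "__main__":
    top = torch.tensor([1, 0, 3, 1, 1, 0])  # 6 tokens routed across 4 experts
    indices, bin_ids, bins, padded_bins, counts = padded_bins_reference(top, 4, blocking=4)
    print(counts)       # tensor([2, 3, 0, 1])
    print(bins)         # tensor([2, 5, 5, 6])
    print(padded_bins)  # tensor([ 4,  8,  8, 12])
```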
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/gelu.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
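The constants in `_gelu_backward_inplace` are a closed-form derivative of the tanh-approximated GELU. A quick dense-tensor check against autograd (pure PyTorch; nothing here touches the sparse `Matrix` wrapper):

```python
import torch
import torch.nn.functional as F


def gelu_tanh_grad(x: torch.Tensor) -> torch.Tensor:
    # Same closed form as _gelu_backward_inplace, without the in-place multiply.
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)


if __name__ == "__main__":
    x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
    F.gelu(x, approximate="tanh").sum().backward()
    print(torch.allclose(x.grad, gelu_tanh_grad(x.detach()), atol=1e-6))  # True
```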
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/glu.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
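Per expert, the grouped GLU computes `act(x @ w1.T) * (x @ v1.T)` and projects the result back through `w2`. A dense single-expert reference sketch of that computation (illustrative only, not the grouped-GEMM path):

```python
import torch
import torch.nn.functional as F


def glu_expert_reference(x, w1, v1, w2, activation=lambda t: F.gelu(t, approximate="tanh")):
    # x: [tokens, hidden]; w1, v1, w2: [ffn, hidden], matching the grouped layout.
    x1 = x @ w1.t()                    # gate branch
    x2 = x @ v1.t()                    # linear branch
    return (activation(x1) * x2) @ w2  # back to [tokens, hidden]


if __name__ == "__main__":
    tokens, hidden, ffn = 8, 16, 32
    x = torch.randn(tokens, hidden)
    w1, v1, w2 = (torch.randn(ffn, hidden) for _ in range(3))
    print(glu_expert_reference(x, w1, v1, w2).shape)  # torch.Size([8, 16])
```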
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/memory_test.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
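The `get_tensors()` sweep enumerates live tensors via the garbage collector and de-duplicates them by data pointer. A CPU-friendly sketch of the same audit (the helper name is hypothetical):

```python
import gc

import torch


def live_tensor_bytes() -> int:
    # Walk the GC heap, skip duplicate views of the same storage, and sum sizes.
    seen, total = set(), 0
    for obj in gc.get_objects():
        if torch.is_tensor(obj) and obj.is_contiguous():
            ptr = obj.data_ptr()
            if ptr in seen:
                continue
            seen.add(ptr)
            total += obj.numel() * obj.element_size()
    return total


if __name__ == "__main__":
    # Hold references so these tensors stay alive during the sweep.
    keep = [torch.randn(1024, 1024), torch.zeros(256, dtype=torch.int64)]
    print(f"{live_tensor_bytes() / 1e6:.1f} MB of live tensors")
```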
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/mlp.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+    # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+ # wieghted by number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
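`ScaleGradient` is an identity in the forward pass but scales gradients by `1 / expert_parallel_world_size` in the backward pass. A stand-alone check of that behaviour (a minimal re-implementation for illustration, not imported from this package):

```python
import torch


class ScaleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x  # identity forward

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None  # scaled gradient, no grad for `scale`


if __name__ == "__main__":
    world_size = 4
    w = torch.randn(3, 3, requires_grad=True)
    (ScaleGrad.apply(w, 1 / world_size) * 2.0).sum().backward()
    print(w.grad)  # every entry is 2 / world_size == 0.5
```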
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/moe.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert entries '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+            # If the computed expert capacity is zero, fall back to the maximum
+            # number of tokens routed to any single expert so that no tokens are dropped.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+        # Replicate the tokens so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+            # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+            # If the computed expert capacity is zero, fall back to the maximum
+            # number of tokens routed to any single expert so that no tokens are dropped.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+        # Reduce along the hidden sharding dimension to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/mpu.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
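+
+
+# Worked example (hypothetical configuration): with an expert-parallel world
+# size of 8 and moe_num_experts=4, expert_sharding_degree() is min(8, 4) = 4,
+# hidden_sharding_degree() is 8 // 4 = 2 and experts_per_rank() is 4 // 4 = 1;
+# for ffn_hidden_size=4096, features_per_rank() is 4096 // 2 = 2048.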
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/router.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
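+# batched_router_zloss: each entry of the returned vector is moe_zloss_weight
+# times the mean over tokens of (logsumexp over experts of that router's
+# logits) squared, i.e. the router z-loss for one router.
+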
+
+# NOTE: To enable end-to-end benchmarking without requiring convergence, we
+# support a flag that forces the router to assign tokens uniformly across
+# the experts. We do this with a custom autograd operation so that PyTorch
+# still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch211-cxx11-cpu-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_megablocks_cpu_7a6bcf4.abi3.so b/build/torch211-cxx11-cpu-x86_64-linux/_megablocks_cpu_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..4fe96907856dba6b076db26ec4f8522939171a26
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_megablocks_cpu_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fe559f32ea12ed42966ea79e77aa2ea0c7bfb5e123e84ac526fc5d94cf6b9a3
+size 2219080
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_ops.py b/build/torch211-cxx11-cpu-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..102573e975ffc7897e0e9c4edca028ed1dc67419
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cpu_7a6bcf4
+ops = torch.ops._megablocks_cpu_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cpu_7a6bcf4::{op_name}"
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/_version.py b/build/torch211-cxx11-cpu-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/backend/__init__.py b/build/torch211-cxx11-cpu-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/backend/kernels.py b/build/torch211-cxx11-cpu-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton autotune when running in an environment without CUDA.
+# This preserves the original code while still allowing tests without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+    # NOTE: There is no padding, so the number of output rows equals the
+    # number of input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per routed token copy (one element of 'wgrad'). Array
+    # 'x' has a greater or equal number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
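+# Round-trip sketch (illustrative): binned_gather packs a [tokens, hidden]
+# activation into a dense [num_experts, expert_capacity, hidden] buffer using
+# the bin boundaries, the expert MLP runs on that buffer, and binned_scatter
+# maps the results back to [tokens, hidden], scaling each copy by its routing
+# weight and summing over the top_k copies.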
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/benchmark_util.py b/build/torch211-cxx11-cpu-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
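+
+
+# Usage sketch (illustrative; timing uses CUDA events, so a GPU is required):
+#   a = torch.randn(4096, 4096, device='cuda')
+#   mean_ms, std_ms = benchmark_function(lambda: a @ a)
+#   log_benchmark('MatMul', {'n': 4096}, mean_ms, std_ms)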
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/cpu_fused_moe.py b/build/torch211-cxx11-cpu-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
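+# Illustrative values: for scalar inputs gate=2.0 and up=1.0 with the default
+# alpha/limit, glu = 2.0 * sigmoid(2.0 * 1.702) ~= 1.94 and the result is
+# (1.0 + 1) * glu ~= 3.87.
+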
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
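+# Shape sketch (illustrative): routing 6 tokens of hidden size 16 across 4
+# experts with moe_top_k=2 returns logits of shape [6, 4] and expert_weights /
+# expert_indices of shape [6, 2].
+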
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation processes all experts in parallel using batched operations
+ instead of sequential for loops, which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Build expert mask: which tokens go to which expert
+ # expert_mask[expert_id] contains indices of (token_idx, topk_pos) pairs
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
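+# Usage sketch (illustrative shapes, not a prescribed API):
+#   x = torch.randn(6, 16)                   # 6 tokens, hidden size 16
+#   w1 = torch.randn(4, 16, 64)              # 4 experts, gate_up to 2 * 32
+#   w2 = torch.randn(4, 32, 16)              # down projection back to 16
+#   _, tw, ti = route_tokens_cpu(x, torch.randn(4, 16), None, 2, 4)
+#   y = cpu_fused_moe(x, w1, w2, tw, ti)     # -> [6, 16]
+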
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/cpu_moe_cpp.py b/build/torch211-cxx11-cpu-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
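+# Usage sketch (illustrative): with weights and routing produced as in
+# cpu_fused_moe, a plain bf16 call looks like
+#   y = fused_moe_cpp(x, w1, w2, topk_weights, topk_ids.to(torch.int32))
+# which takes the default non-quantized path; passing alpha and limit selects
+# the swigluoai activation instead of silu_and_mul.
+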
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm/__init__.py b/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm/ops.py b/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
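+
+
+# A rough usage sketch for the wrapper above (illustrative only: shapes are
+# assumed, and dtype/device requirements are whatever the underlying grouped
+# GEMM backend expects):
+#
+#   a = torch.randn(6, 16, device="cuda")        # 6 tokens, hidden size 16
+#   b = torch.randn(3, 16, 32, device="cuda")    # 3 experts, 16 -> 32
+#   batch_sizes = torch.tensor([2, 1, 3])        # tokens per expert, sums to 6
+#   out = gmm(a, b, batch_sizes)                 # -> shape (6, 32)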
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm_util.py b/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored in this repository, so there is no external
+    # package to import; mark it as available.
+    _grouped_gemm_is_available = True
+except ImportError:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/layers.py b/build/torch211-cxx11-cpu-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree based on world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
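+
+
+# A rough worked example for the sharding helpers above (assumed values, for
+# illustration only): with world_size=8 and moe_num_experts=4,
+# expert_sharding_degree() == 4 and experts_per_rank() == 1, so
+# hidden_sharding_degree() == 8 // 4 == 2 and an ffn_hidden_size of 3072 is
+# split into features_per_rank() == 1536 features per rank.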
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+    router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
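+
+# Shape sketch for route_tokens above (illustrative, assuming x is [sl, bs, hs]):
+#   logits:         [sl * bs, moe_num_experts]
+#   expert_weights: [sl * bs, moe_top_k]  (softmax over the selected top-k logits)
+#   expert_indices: [sl * bs, moe_top_k]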
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
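+
+# Note on the activation above (a description of the code, with the default
+# alpha=1.702 and limit=7.0): w1's output is de-interleaved into gate/up halves,
+# gate is clamped from above at `limit`, up is clamped to [-limit, limit], and
+# the result fed to w2 is (up + 1) * gate * sigmoid(alpha * gate).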
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
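+
+# Example weighting (illustrative): with shared_expert_weighted_sum=True and
+# moe_top_k=3, the shared expert contributes 1/4 of the combined output and the
+# routed experts contribute 3/4; otherwise the two outputs are simply summed.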
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
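+
+# Worked example for the loss above (assumed numbers, illustration only):
+# tokens=8, top_k=2, num_experts=4, tokens_per_expert=[4, 4, 4, 4] and a
+# perfectly uniform router (mean score 0.25 per expert) give
+#   scale = 4 / (8 * 2) = 0.25 and loss = 0.25 * (4 * 0.25 * 4) = 1.0,
+# the value produced by a perfectly balanced assignment.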
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
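+
+# Illustrative output of indices_and_bins (assuming a stable sort): for
+# top_expert = [1, 0, 1, 2] and num_experts = 3,
+#   bin_ids           == [0, 1, 1, 2]   (expert id of each token, sorted by expert)
+#   indices           == [1, 0, 2, 3]   (original position of each sorted token)
+#   tokens_per_expert == [1, 2, 1]
+#   bins              == [1, 3, 4]      (inclusive cumsum of the counts)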
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
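+
+# Rough numeric example (assumed values): tokens=4096, top_k=4, num_experts=128,
+# a single rank and moe_capacity_factor=1.0 give
+# tokens_per_expert = 4 * 4096 * 1 / 128 = 128, so the capacity is 128 slots
+# per expert.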
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Start the token-count exchange without blocking
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
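+
+    # Rough usage sketch (illustrative only; shapes, the `hidden` size, and the
+    # initialization are assumptions, not part of this module's API):
+    #   layer = MegaBlocksMoeMLPWithSharedExpert()
+    #   ... set up layer.router / layer.experts as for MegaBlocksMoeMLP ...
+    #   layer.set_shared_expert_weights(
+    #       up_proj_weight=torch.randn(4 * hidden, hidden, device="cuda") * 0.02,
+    #       down_proj_weight=torch.randn(hidden, 4 * hidden, device="cuda") * 0.02,
+    #       weighted_sum=True,
+    #   )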
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/megablocks/__init__.py b/build/torch211-cxx11-cpu-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/metadata.json b/build/torch211-cxx11-cpu-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..eb22148b3f551be150f7824a5684c19bbc40ae0e
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/metadata.json
@@ -0,0 +1,8 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cpu"
+ }
+}
\ No newline at end of file
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/__init__.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per fp16 element.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/binned_gather.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/binned_scatter.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/cumsum.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("megablocks C++ ops extension ('megablocks._ops') is not available.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
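+
+
+# Semantics sketch (illustrative): for a 1D input x = [1, 2, 1],
+# inclusive_cumsum(x, 0) -> [1, 3, 4] and exclusive_cumsum(x, 0) -> [0, 1, 3].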
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/gather.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/histogram.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("megablocks C++ ops extension ('megablocks._ops') is not available.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/histogram_benchmark.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/matmul_benchmark.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/padded_gather.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/padded_scatter.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
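+        # Save the (potentially large) activations only when a gradient for
+        # the routing weights (forward input index 3) will be required.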
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/permute_benchmark.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/repeat.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/replicate.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/round_up.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
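+    # For example, round_up(torch.tensor([5, 128, 129], dtype=torch.int32), 128)
+    # yields tensor([128, 128, 256], dtype=torch.int32).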
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/scatter.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/sort.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
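+        # end_bit defaults to the full bit width of x's dtype; callers that
+        # know the maximum value can pass fewer bits (see sort_benchmark.py).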
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/sort_benchmark.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/stk_autocast.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
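+# Matching decorator for the backward pass: it runs the backward body with
+# autocast disabled so saved tensors and incoming gradients are used with
+# their existing dtypes rather than being re-cast.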
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/sum.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
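+    # When the reduced dimension has a single element, squeezing gives the
+    # same result as the reduction but returns a view instead of computing a sum.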
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/ops/topology.py b/build/torch211-cxx11-cpu-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/__init__.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/__init__.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/autocast.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/sputnik.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/triton_kernels.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
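+# For example, offsets = [0, 2, 3, 6] produces out = [0, 0, 1, 2, 2, 2]:
+# each nonzero block gets the index of the block row it belongs to.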
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/matrix.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D data.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
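+    # For example, row_indices = [0, 1, 1] and column_indices = [2, 0, 1]
+    # give gather_indices = [1, 2, 0] and column_indices_t = [1, 1, 0].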
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
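+    For an (R, C) matrix with square blocks of size B, `data` holds the
+    nonzero blocks as an (nnz, B, B) tensor, `row_indices` and
+    `column_indices` give each block's block row and block column, and
+    `offsets` (length R // B + 1) marks where each block row starts in
+    `data`.
+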
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
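+
+
+# Illustrative sketch (added for clarity, not part of the original source): how the
+# BCSR fields relate to a dense layout. For an 8x8 matrix with blocking=4 whose
+# nonzero blocks sit at block coordinates (0, 0), (0, 1) and (1, 1):
+#
+#   data:           shape [3, 4, 4]   # one [blocking, blocking] tile per nonzero block
+#   row_indices:    [0, 0, 1]         # block-row of each tile
+#   column_indices: [0, 1, 1]         # block-column of each tile
+#   offsets:        [0, 2, 3]         # CSR-style offsets over block-rows
+#
+# The *_t fields (column_indices_t, offsets_t, block_offsets_t) store the same
+# topology for the transposed view, which is why t() can flip a flag instead of
+# shuffling data.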
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/__init__.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
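+
+
+# Minimal usage sketch (illustrative, not part of the original source). Both operands
+# must share the same block topology; reusing a's metadata, as the tests do in
+# _dense_and_sparse_like, guarantees that:
+#
+#   a = stk.random.randn((256, 256), sparsity=0.5, blocking=128).cuda()
+#   b = stk.Matrix(a.size(), torch.ones_like(a.data), a.row_indices,
+#                  a.column_indices, a.offsets)
+#   c = mul(a, b)  # entries match torch.mul on the dense equivalents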
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/linear_ops.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
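+
+
+# Naming note (added for clarity, not part of the original source): each letter
+# describes an operand in output/lhs/rhs order, with "d" = dense and "s" = sparse.
+# So dsd computes a dense output from sparse @ dense, dds computes dense @ sparse,
+# and sdd produces a sparse output with topo's topology from two dense inputs.
+# A rough usage sketch (a, b, w are assumed dense CUDA tensors):
+#
+#   topo = stk.random.mask(m, n, sparsity=0.8, blocking=128).cuda()
+#   y = sdd(a, b, topo)   # sparse [m, n] output restricted to topo's blocks
+#   z = dsd(y, w)         # dense output, sparse lhs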
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
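+
+
+# Round-trip sketch (illustrative, not part of the original source): to_sparse keeps
+# only blocks containing at least one nonzero, so converting back with to_dense
+# reproduces a block-sparse input exactly:
+#
+#   m = stk.random.dense_mask(256, 256, sparsity=0.5, blocking=128)
+#   x = (torch.randn(256, 256) * m).half()
+#   assert torch.equal(to_dense(to_sparse(x, blocking=128)), x)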
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/random/__init__.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/random/random_ops.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
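+
+
+# Sizing note (added for clarity, not part of the original source): sparsity is
+# applied at block granularity, so the number of nonzero scalars is
+# round(block_rows * block_cols * (1 - sparsity)) * blocking ** 2. For example,
+# dense_mask(128, 256, sparsity=0.5, blocking=16) keeps round(8 * 16 * 0.5) = 64
+# blocks, i.e. 64 * 16 * 16 = 16384 nonzero entries.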
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/stk/random/random_ops_test.py b/build/torch211-cxx11-cpu-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cpu-x86_64-linux/xpu_fused_moe.py b/build/torch211-cxx11-cpu-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch211-cxx11-cpu-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# default
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+        for x in arr:
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
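+
+
+# Example (illustrative, not part of the original source): an 8-byte uint8 buffer
+# reinterpreted as int32 yields a 2-element tensor; this is equivalent to
+# buf.view(torch.int32), just via an explicit byte copy:
+#
+#   buf = torch.zeros(8, dtype=torch.uint8)
+#   _bytes_to_typed_tensor(buf, torch.int32)  # -> tensor([0, 0], dtype=torch.int32)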
+
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
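+
+
+# Worked example (added for clarity, not part of the original source): for the input
+# byte 0x3D the high/low u4 nibbles are 3 and 13; subtracting the zero point of 8
+# gives signed values -5 and 5, which re-pack (sign bit plus the low 3 bits of the
+# two's complement) to nibbles 0xB and 0x5, so the output byte is 0xB5.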
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ func; temporarily exposed before gemm fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
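+
+
+# Shape sketch (added for clarity, not part of the original source): with x of shape
+# [tokens, hidden] and router_weight of shape [num_experts, hidden], logits is
+# [tokens, num_experts] while expert_weights and expert_indices are both
+# [tokens, moe_top_k]. Note the softmax is taken over the selected top-k logits
+# only, not over all experts.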
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
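+
+
+# Example (illustrative, not part of the original source): for x = [1, 2, 3],
+# cumsum(x) returns [1, 3, 6] while cumsum(x, exclusive=True) returns [0, 1, 3].
+# The exclusive form is the usual way to turn per-expert token counts into
+# first-token offsets, mirroring expert_first_token_offset elsewhere in this repo.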
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/activation_fn.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/all_to_all.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/arguments.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/common.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
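+
+
+# For example, Arguments(mlp_type='glu', mlp_impl='grouped') resolves to
+# glu.GroupedGLU, while the defaults (mlp_type='mlp', mlp_impl='sparse')
+# resolve to mlp.SparseMLP.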
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/dmoe.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
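+    # An inclusive cumsum over a single-element tensor can come back 0-dim;
+    # promote it to 1-D so downstream indexing and concatenation work.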
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
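+        # For example (hypothetical sizes): with ffn_hidden_size=4096,
+        # num_experts=8 and blocking=128, max_column_index is 256 and the
+        # transpose sort uses 8 bits.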
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
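+        # For instance (hypothetical sizes): with 256 padded tokens,
+        # blocking=128 and ffn_hidden_size=4096 there are 2 block rows of
+        # 32 blocks each, so offsets is [0, 32, 64].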
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
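+        # Example (hypothetical counts): tokens_per_expert=[3, 0, 5] with
+        # blocking=128 rounds to [128, 0, 128], giving padded_bins=[128, 128, 256].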
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/gelu.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
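+    # Derivative of the tanh-approximated GELU, applied in place to the
+    # incoming gradient g: 0.79788456 is sqrt(2/pi) and
+    # 0.1070322243 is 3 * 0.044715 * sqrt(2/pi).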
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/glu.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Activation function and GLU gating.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/memory_test.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
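+    # The factor of 2 assumes 2-byte (bf16/fp16) parameter elements.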
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/mlp.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
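+
+    # Example (hypothetical configuration): with a world size of 4 and
+    # 8 experts, expert_sharding_degree is 4 and hidden_sharding_degree is 1,
+    # so rank 2 owns experts [4, 6) and all ffn rows.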
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
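+        # x arrives as [experts_per_rank, expert_capacity, hidden_size] from
+        # binned_gather in ParallelMLP, so both projections are batched
+        # matmuls, one per local expert.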
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+        # Apply the activation function.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+ # wieghted by number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/moe.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert entries '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
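+        # Example (hypothetical sizes): 4096 tokens, top_k=2, world_size=1,
+        # 8 experts and capacity_factor=1 gives a capacity of 1024 tokens per expert.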
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+            # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/mpu.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
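+
+
+# Worked example (hypothetical sizes, not from the library): with world_size=8,
+# moe_num_experts=4 and ffn_hidden_size=1024, expert_sharding_degree() is 4 and
+# hidden_sharding_degree() is 2, so experts_per_rank() is 1 and
+# features_per_rank() is 512.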
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/router.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence, we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
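+
+# Illustrative example: for an expert_indices tensor with 4 elements and
+# num_experts=2, the assignment above yields [0, 1, 0, 1] regardless of the
+# routing scores, keeping every expert equally loaded for benchmarking.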
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_layers/sharedexpert_registry.py b/build/torch211-cxx11-cu126-aarch64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch211-cxx11-cu126-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..d122a7970e66f05517764a5bbe5efd723671111a
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bccbfca5181c0702db62b8285d63aa8c380902bf70555369e1a7b6c979009a01
+size 15124328
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_version.py b/build/torch211-cxx11-cu126-aarch64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/backend/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/backend/kernels.py b/build/torch211-cxx11-cu126-aarch64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/benchmark_util.py b/build/torch211-cxx11-cu126-aarch64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+    print(f'mean time = {time:.3f}ms, std time = {std:.3f}ms')
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
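+
+
+# Minimal usage sketch (added for illustration; assumes a CUDA device is available):
+if __name__ == "__main__" and torch.cuda.is_available():
+    x = torch.randn(2048, 2048, device="cuda")
+    mean_ms, std_ms = benchmark_function(lambda: torch.matmul(x, x))
+    log_benchmark("matmul", {"m": 2048, "n": 2048, "k": 2048}, mean_ms, std_ms)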
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/cpu_fused_moe.py b/build/torch211-cxx11-cu126-aarch64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
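+
+
+# Illustrative check (hypothetical values): with the defaults, a gate of 10 is
+# clamped to the limit of 7, so
+#   swigluoai_activation(torch.tensor([10.0]), torch.tensor([0.0]))
+# evaluates to (0 + 1) * 7 * sigmoid(7 * 1.702), which is approximately 7.0.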
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
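+
+
+# Usage sketch (hypothetical shapes): for x of shape (tokens, hidden) and
+# router_weight of shape (num_experts, hidden),
+#
+#   logits, weights, ids = route_tokens_cpu(x, router_weight, None,
+#                                           moe_top_k=2, moe_num_experts=num_experts)
+#
+# returns logits of shape (tokens, num_experts) and weights/ids of shape (tokens, 2).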
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation iterates over the experts and, for each one, processes all
+    of its assigned tokens with batched tensor operations, which keeps the CPU path
+    simple while avoiding a per-token Python loop.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # Process one expert at a time: for each expert, find the (token_idx, topk_pos)
+    # pairs in topk_ids that were routed to it.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
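+
+
+# End-to-end sketch on CPU (added for illustration only; sizes are hypothetical).
+if __name__ == "__main__":
+    tokens, hidden, inter, experts, top_k = 4, 8, 16, 4, 2
+    x = torch.randn(tokens, hidden)
+    router_w = torch.randn(experts, hidden)
+    w1 = torch.randn(experts, hidden, 2 * inter) * 0.02  # gate_up_proj weights
+    w2 = torch.randn(experts, inter, hidden) * 0.02      # down_proj weights
+    _, weights, ids = route_tokens_cpu(x, router_w, None, top_k, experts)
+    out = cpu_fused_moe(x, w1, w2, weights, ids, activation="silu", is_interleaved=True)
+    assert out.shape == (tokens, hidden)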
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/cpu_moe_cpp.py b/build/torch211-cxx11-cu126-aarch64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
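+
+
+# Call sketch (illustrative only; requires the compiled C++ ops and follows the
+# shapes documented in the docstring above): bf16 activations, no quantization,
+# standard SwiGLU.
+#
+#   out = fused_moe_cpp(x_bf16, w1, w2, topk_weights, topk_ids.to(torch.int32),
+#                       w1_bias=b1, w2_bias=b2)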
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm/ops.py b/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
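+
+
+# Usage sketch (illustrative; assumes the compiled grouped GEMM backend and CUDA):
+#
+#   a = torch.randn(10, 64, device="cuda", dtype=torch.bfloat16)      # all rows, concatenated
+#   b = torch.randn(2, 64, 128, device="cuda", dtype=torch.bfloat16)  # one matrix per group
+#   batch_sizes = torch.tensor([6, 4])                                # rows per group
+#   c = gmm(a, b, batch_sizes)                                        # shape (10, 128)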
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm_util.py b/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm
+    _grouped_gemm_is_available = True
+except ImportError:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/layers.py b/build/torch211-cxx11-cu126-aarch64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
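+
+# Illustrative effect of the patching above: while torch.compile is tracing,
+# ops.sort(x) returns empty tensors with x's shape and dtype instead of calling
+# into the CUDA kernel, so the graph can be captured without real device work.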
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree from the world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
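+
+# Illustrative arithmetic for the sharding helpers above (example values, not a
+# fixed configuration): with world_size=8, moe_num_experts=128 and
+# ffn_hidden_size=3072, expert_sharding_degree -> 8, hidden_sharding_degree -> 1,
+# experts_per_rank -> 16 and features_per_rank -> 3072. With world_size=16 and
+# moe_num_experts=8, expert_sharding_degree -> 8, hidden_sharding_degree -> 2
+# (so 8 * 2 == 16), experts_per_rank -> 1 and features_per_rank -> 1536.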
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
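+
+# For example, moe_jitter_eps=0.01 scales each activation by an independent
+# uniform factor drawn from [0.99, 1.01].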
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
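+
+# Shape sketch for route_tokens (assuming x is [seq_len, batch, hidden]):
+# x_flat is [seq_len * batch, hidden], logits is [tokens, moe_num_experts], and
+# expert_weights / expert_indices are [tokens, moe_top_k]. Note the softmax is
+# taken over the selected top-k logits only, not over all experts.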
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
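+
+# Shape/layout sketch for mlp_forward (shapes inferred from the bmm calls):
+# x is [num_local_experts, expert_capacity, hidden], w1 is
+# [num_local_experts, hidden, 2 * ffn] with gate and up columns interleaved
+# (even columns -> gate, odd columns -> up), and w2 is
+# [num_local_experts, ffn, hidden]. The activation gate * sigmoid(alpha * gate)
+# with alpha = 1.702 is the sigmoid approximation of GELU, and `limit` clamps
+# the gate/up activations before the GLU product.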
+
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
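+
+# For example, with shared_expert_weighted_sum=True and moe_top_k=4, the shared
+# expert output is weighted by 1/5 and the routed-expert output by 4/5.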
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} tokens_per_expert entries "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup],
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
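+
+# For example, with tokens=4096, top_k=4, num_experts=128, a single rank and
+# moe_capacity_factor=1.0, each expert gets int(1.0 * 4 * 4096 / 128) = 128
+# token slots.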
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
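+
+# This follows the Switch-Transformer-style auxiliary loss:
+#   loss = (num_experts / (tokens * top_k)) * sum_e tokens_per_expert[e] * mean_score[e]
+# which encourages a uniform token distribution across experts.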
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
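+
+# Small illustrative example: with top_expert = [2, 0, 2, 1] and num_experts = 3,
+# tokens_per_expert is [1, 1, 2] and bins (the inclusive cumsum) is [1, 2, 4];
+# bin_ids holds the expert id for each slot of the sorted token order and
+# indices maps those sorted slots back to the original token positions.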
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
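+
+# forward_once data flow, in brief: flatten the routing metadata, sort token
+# assignments by expert (indices_and_bins), gather tokens into a dense
+# [num_experts, expert_capacity, hidden] buffer (binned_gather), run the expert
+# MLPs as batched matmuls (mlp_forward), then scatter the results back to the
+# original token order weighted by expert_weights (binned_scatter).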
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Asynchronously pass the token counts to the ranks that own the target experts
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
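+
+# Illustrative send/recv accounting for parallel_forward_once: with world_size=2
+# and num_experts=4 (so experts_per_rank=2 and hidden sharding degree 1), a
+# local tokens_per_expert of [3, 1, 2, 0] yields send_counts = [4, 2]: this rank
+# sends 3 + 1 = 4 tokens to the rank hosting experts {0, 1} and 2 + 0 = 2 tokens
+# to the rank hosting experts {2, 3}; recv_counts come from the transposed
+# exchange of these counts via all_to_all_single.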
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+    moe_loss_weight = 0.0  # Fixed at 0.0 for now, so the loss below is never saved; make configurable to enable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
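+# Illustrative usage sketch for the shared-expert variant (hypothetical values;
+# the router/experts attributes are attached exactly as for MegaBlocksMoeMLP):
+#
+#   mlp = MegaBlocksMoeMLPWithSharedExpert()
+#   ...  # attach mlp.router and mlp.experts as usual
+#   up_w, down_w, _, _ = create_shared_expert_weights(
+#       hidden_size=1152,
+#       shared_expert_hidden_size=3072,
+#       device=torch.device("cuda"),
+#       dtype=torch.bfloat16,
+#       init_method=torch.nn.init.kaiming_uniform_,
+#   )
+#   mlp.set_shared_expert_weights(up_w, down_w, weighted_sum=True)
+#   output, expert_weights = mlp(hidden_states)
+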
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/megablocks/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/metadata.json b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..155112c59509d3b4d07f4d090cbf57071e3f5217
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/all_to_all_benchmark.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per element (fp16).
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/binned_gather.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/binned_scatter.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/cumsum.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/gather.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/histogram.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/padded_gather.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/padded_scatter.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/repeat.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/replicate.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
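+
+
+# Added illustrative comment: as the shapes suggest, the forward pass
+# broadcasts each per-bin value of `x` (with `bins` holding cumulative bin
+# boundaries) across that bin's span of the `num_outputs` output columns,
+# and the backward pass reduces the incoming gradient back to one value
+# per bin (output shape (grad.shape[0], bins.shape[0])).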
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/round_up.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
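+
+
+# Added worked example: with value=128,
+#   round_up(torch.tensor([5, 128, 130], dtype=torch.int32), 128)
+# returns tensor([128, 128, 256]) -- each entry is rounded up to the next
+# multiple of 128 via truncating division of (x + 127) by 128.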
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/scatter.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/sort.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
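+
+
+# Added illustrative comment: the op returns the sorted values together with
+# the permutation that produced them, roughly torch.sort restricted to the
+# integer dtypes above. For example, with
+# x = torch.tensor([2, 0, 1, 0], dtype=torch.int32), a stable radix sort
+# yields x_out = [0, 0, 1, 2] and iota_out = [1, 3, 2, 0]. The benchmarks in
+# this build use it as `bin_ids, indices = ops.sort(top_expert)` to group
+# token positions by their assigned expert.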
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/stk_autocast.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/sum.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/ops/topology.py b/build/torch211-cxx11-cu126-aarch64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
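+
+
+# Added illustrative comment: given the padded per-expert bin boundaries,
+# ops.indices emits one int16 entry per potential nonzero block of an
+# `output_block_rows` x `output_block_columns` block grid. Presumably the
+# result is consumed as the column indices of the block-sparse (stk.Matrix)
+# topology that the dMoE expert matmuls operate on.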
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/autocast.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/sputnik.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
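+
+
+# Added reference comment (standard matmul calculus, not specific to this
+# file): for C = op_a(A) @ op_b(B) with upstream gradient dC,
+#   C = A   @ B     ->  dA = dC @ B^T,    dB = A^T @ dC
+#   C = A   @ B^T   ->  dA = dC @ B,      dB = dC^T @ A
+#   C = A^T @ B     ->  dA = B @ dC^T,    dB = A @ dC
+#   C = A^T @ B^T   ->  dA = B^T @ dC^T,  dB = dC^T @ A^T
+# _lhs_gradient/_rhs_gradient above choose the operand order and transpose
+# flags so that the sdd/dsd/dds kernels compute exactly these products.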
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/triton_kernels.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
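+
+# Added worked example: for a BCSR matrix with offsets = [0, 2, 3, 6] the
+# kernel writes one row id per nonzero block, so out = [0, 0, 1, 2, 2, 2]
+# (row 0 owns two blocks, row 1 one block, row 2 three blocks).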
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/matrix.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
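+
+
+# Added worked example: for a 2x2-block matrix with nonzero blocks at
+# (0, 1) and (1, 0), i.e. row_indices = [0, 1] and column_indices = [1, 0],
+# the transpose meta-data is
+#   column_indices_t = [1, 0]    (row ids reordered column-major)
+#   offsets_t        = [0, 1, 2] (one block per column of the transpose)
+#   block_offsets_t  = [1, 0]    (where each transposed block lives in data)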
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/linear_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
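+
+
+# Added naming note: the three letters encode output / lhs / rhs storage --
+# dsd: dense output = sparse @ dense, dds: dense output = dense @ sparse,
+# sdd: sparse output (with the topology of `topo`) = dense @ dense.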
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
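+# Converts a block-sparse Matrix back to a dense tensor by scattering each
+# block's values into its (row, column) position.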
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
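+# Block-level nonzero mask: one entry per `blocking` x `blocking` block, True when
+# the block contains any nonzero value.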
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
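+# Converts a dense tensor to block-sparse Matrix format; all-zero blocks are
+# dropped from the representation.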
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/random/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/random/random_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
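+# Samples a dense {0, 1} mask with block-structured sparsity at the requested rate.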
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py b/build/torch211-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from .. import random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/xpu_fused_moe.py b/build/torch211-cxx11-cu126-aarch64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# default
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
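+# Choose the smallest candidate block size for which the blocked expert-count
+# table (num_blocks_per_seq * num_experts) fits within one block; otherwise 1024.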
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+ # Convert u4 (unsigned 4-bit) values to s4 (signed 4-bit) so the GEMM kernel
+ # does not need an explicit zero point. Only the default zero point (8) is supported.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+ # TODO: integrate all of this into the C++ function. Temporarily exposed here until GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
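+ # Pack all intermediate buffers into a single uint8 workspace; each region is
+ # padded to a 256-byte boundary and recorded in ws_map as (size, offset).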
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
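+# Illustrative usage (assumed shapes and dtypes, not part of the original code):
+#   x = torch.randint(0, 8, (1024,), dtype=torch.int32, device="cuda")
+#   counts = histogram(x, num_bins=8)
+#   sorted_vals, sorted_idx = argsort(x, end_bit=3)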
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/activation_fn.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
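+# Applies an elementwise activation to the values of a block-sparse Matrix while
+# preserving its topology; with return_grad_fn=True it also returns out.backward
+# so the activation can be rematerialized in the backward pass.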
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/all_to_all.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
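+# Autograd-aware wrapper around dist.all_to_all_single; the backward pass reruns
+# the collective with the input/output split sizes swapped.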
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/arguments.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+ shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/common.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+ (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/dmoe.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+ # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+ # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/gelu.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
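+
+
+# Sanity sketch (hedged example, not part of the original module): the manual
+# tanh-GELU derivative in _gelu_backward_inplace should agree with autograd to
+# within float tolerance, e.g.
+#
+#   x = torch.randn(16, requires_grad=True)
+#   y = F.gelu(x, approximate='tanh')
+#   y.backward(torch.ones_like(x))
+#   manual = _gelu_backward_inplace(torch.ones_like(x), x.detach())
+#   torch.testing.assert_close(manual, x.grad)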
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/glu.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported for GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # Activation function with GLU gating.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> paste -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/memory_test.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6))
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6))
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/mlp.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
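+ # Illustrative sketch (toy values, not from the original code): with
+ # moe_num_experts=8, expert_sharding_degree=4 and hidden_sharding_degree=2
+ # (an expert-parallel world size of 8), rank 5 has expert_rank = 5 % 4 = 1
+ # and row_rank = 5 // 4 = 1, so it keeps experts [2, 4) and rows
+ # [ffn_hidden_size // 2, ffn_hidden_size) of the master weights.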
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+ # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+ # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # Enable using a weighted sum for the shared expert output,
+ # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
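+ # Worked example (assumed moe_top_k=4, illustration only): with the
+ # weighted sum enabled, t_experts = 5 and the blend is
+ # shared_expert_out / 5 + expert_out * (4 / 5), i.e. the shared expert is
+ # treated as one of five equally weighted contributors.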
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/moe.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
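+ # Worked arithmetic (toy values, illustration only): with top_k=2,
+ # tokens=1024, an expert-parallel world size of 4, 64 experts and
+ # moe_capacity_factor=1.25, tokens_per_expert = 2 * 1024 * 4 / 64 = 128
+ # and the method returns int(1.25 * 128) = 160 slots per expert.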
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
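+ # In symbols (restating the code above, hedged): loss = num_experts /
+ # (tokens * top_k) * dot(f, p), where f[e] counts tokens routed to expert e
+ # and p[e] is the mean router probability assigned to expert e over the
+ # batch, as in the Switch Transformer auxiliary loss.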
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+ # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/mpu.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
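+
+
+# Worked example (assumed configuration, illustration only): with an
+# expert-parallel world size of 8, moe_num_experts=4 and ffn_hidden_size=4096,
+# expert_sharding_degree(args) = min(8, 4) = 4 and
+# hidden_sharding_degree(args) = 8 // 4 = 2, so each rank owns
+# experts_per_rank = 1 expert and features_per_rank = 2048 of its ffn rows.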
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/router.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+ # so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
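+
+
+# Shape sketch (inferred from the forward pass above, hedged): for input x of
+# shape [sl, bs, hs], `scores` is [sl * bs, moe_num_experts] while
+# `expert_weights` and `expert_indices` are both [sl * bs, moe_top_k].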
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch211-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
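+
+
+# Usage sketch (assumed Arguments fields, illustration only):
+#
+#   args = Arguments(..., mlp_type='glu', shared_expert=True)
+#   shared = get(args)  # -> glu.SharedGLU instance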
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch211-cxx11-cu126-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..4850e7984e1967316cf9646f8d1f1869af56e094
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e3d80af1e9bea8f67e6377db54e0b11f61bda494c42b5f5612a9f93eebc5ef55
+size 15061056
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+    Prefix an op name with the extension's namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
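+
+# Illustrative example (not executed): assuming the extension registers an op
+# named "sort", the helper simply prepends the extension namespace, e.g.
+#   add_op_namespace_prefix("sort") -> "_megablocks_cuda_7a6bcf4::sort"
+# which refers to the same kernel reachable via `ops.sort` above.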
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_version.py b/build/torch211-cxx11-cu126-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/backend/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/backend/kernels.py b/build/torch211-cxx11-cu126-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
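+# Illustrative sketch (hypothetical values): with top_k=1, two experts,
+# bins=[2, 3] and padded_bins=[4, 6] (each bin rounded up for alignment),
+# padded_gather copies the 3 routed rows of 'x' into a zero-initialized
+# (6, hidden_size) buffer; rows 2-3 and 5 remain as padding.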
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per routed token (one entry of 'wgrad'). Array 'x' has a
+    # greater or equal number of rows since its rows may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
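+# Shape sketch (illustrative): for x of shape (tokens, hidden_size), binned_gather
+# returns (num_experts, expert_capacity, hidden_size). The launch grid has one
+# program per (expert, capacity slot); slots beyond an expert's token count stay
+# zero, and tokens beyond its capacity are dropped.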
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/benchmark_util.py b/build/torch211-cxx11-cu126-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
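+
+
+# Example usage (illustrative; requires a CUDA device):
+#
+#   a = torch.randn(4096, 4096, device="cuda")
+#   b = torch.randn(4096, 4096, device="cuda")
+#   mean_ms, std_ms = benchmark_function(lambda: a @ b)
+#   log_benchmark("matmul", {"m": 4096, "n": 4096, "k": 4096}, mean_ms, std_ms)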
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/cpu_fused_moe.py b/build/torch211-cxx11-cu126-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
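+# Illustrative sketch (hypothetical sizes): routing 6 tokens of width 16 over
+# 4 experts with top_k=2 on CPU.
+#
+#   x = torch.randn(6, 16)
+#   logits, weights, ids = route_tokens_cpu(x, torch.randn(4, 16), None, 2, 4)
+#   # logits: (6, 4); weights/ids: (6, 2); weights sum to 1 per token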
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over the experts, but processes all tokens assigned
+    to each expert with batched matrix operations rather than a per-token loop,
+    which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # Process one expert at a time. For each expert, build a boolean mask over the
+    # (token_idx, topk_pos) grid marking which tokens were routed to that expert.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
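+# Illustrative sketch (hypothetical sizes): calling the fused CPU path directly
+# with 4 experts, hidden size 16, intermediate size 32 and top_k=2.
+#
+#   hidden = torch.randn(8, 16)
+#   w1 = torch.randn(4, 16, 64) * 0.02   # gate_up_proj, 2 * inter_size
+#   w2 = torch.randn(4, 32, 16) * 0.02   # down_proj
+#   _, weights, ids = route_tokens_cpu(hidden, torch.randn(4, 16), None, 2, 4)
+#   out = cpu_fused_moe(hidden, w1, w2, weights, ids, activation="silu")
+#   # out: (8, 16), same shape as the flattened input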
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/cpu_moe_cpp.py b/build/torch211-cxx11-cu126-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py b/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
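+
+
+# Illustrative sketch (hypothetical sizes; requires the CUDA extension): a grouped
+# GEMM over two expert groups holding 3 and 5 rows of 'a' respectively.
+#
+#   a = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)
+#   b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16)
+#   batch_sizes = torch.tensor([3, 5])
+#   c = gmm(a, b, batch_sizes)   # c: (8, 32); backward is handled by GroupedGemm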
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm_util.py b/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # import grouped_gemm  # vendored in this package, so no external import is needed
+    _grouped_gemm_is_available = True
+except ImportError:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.',
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/layers.py b/build/torch211-cxx11-cu126-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree based on world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
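+# Worked example (illustrative): with world_size=8, moe_num_experts=4 and
+# ffn_hidden_size=1536, the expert sharding degree is min(8, 4) = 4 and the
+# hidden sharding degree is 8 // 4 = 2, so each rank holds 4 // 4 = 1 expert
+# and 1536 // 2 = 768 of its ffn features.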
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
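+# Shape sketch (illustrative): mlp_forward expects per-expert batched tensors,
+# e.g. x: (num_experts, expert_capacity, hidden_size), w1: (num_experts,
+# hidden_size, 2 * ffn_size) with gate/up interleaved on the last dim,
+# w2: (num_experts, ffn_size, hidden_size), and biases of shape
+# (num_experts, 2 * ffn_size) and (num_experts, hidden_size); the output keeps
+# the (num_experts, expert_capacity, hidden_size) layout.
+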
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} tokens_per_expert entries "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
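+    """Sort tokens by their assigned expert and build routing metadata.
+
+    Returns the permutation indices, the expert id of each sorted slot
+    (bin_ids), the inclusive-cumsum bin boundaries, and the per-expert
+    token counts.
+    """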
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
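+    """Number of token slots allocated to each expert:
+    int(moe_capacity_factor * top_k * tokens * world_size / num_experts).
+
+    For example, 4096 tokens with top_k=4, 128 experts, world_size=1, and a
+    capacity factor of 1.0 gives a capacity of 128 tokens per expert.
+    """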
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
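+    """Gather tokens into per-expert bins, run the expert MLP, and scatter
+    the results back to the original token order, weighted by the router
+    probabilities in expert_weights.
+    """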
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
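+    """Single-device MoE forward pass: compute routing metadata and expert
+    capacity, then permute the tokens, apply the experts, and un-permute.
+    Returns the expert output and the per-expert token counts.
+    """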
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
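+    """Expert-parallel MoE forward pass.
+
+    Tokens are permuted locally, exchanged across the expert-parallel group
+    with all-to-all, processed by the local experts, sent back, reduced over
+    the hidden-sharding dimension, and finally un-permuted. Returns the
+    expert output and the per-expert token counts.
+    """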
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Launch the count exchange asynchronously so it overlaps with the
+        # local permutation below.
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: Optional[torch.Tensor] = None,
+    w2: Optional[torch.Tensor] = None,
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+    mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
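+    """Route tokens and dispatch them through forward_fn (single-device or
+    expert-parallel). Returns the layer output, the top-k expert weights,
+    and the per-token router scores.
+    """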
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+    # Save the load-balancing loss for later aggregation. The weight is
+    # currently hard-coded to 0.0, so this branch is disabled until it is
+    # made configurable.
+    moe_loss_weight = 0.0
+    if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: Optional[torch.Tensor] = None,
+    w2: Optional[torch.Tensor] = None,
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
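+    """MoE forward pass with an optional shared (always-active) expert.
+
+    When shared expert weights are provided, the shared expert output is
+    combined with the routed expert output; otherwise this reduces to a
+    plain moe_forward call.
+    """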
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
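+    """Allocate and initialize up/down projection weights for a shared
+    expert. Biases are not created, so the last two return values are None.
+    """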
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
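+        """Read the MoE configuration from the attached router/experts
+        modules, select the single-device or expert-parallel path, and
+        return the layer output together with the routing weights.
+        """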
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
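+        """Same as MegaBlocksMoeMLP.forward, but combines the routed expert
+        output with the shared expert output when shared weights are set.
+        """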
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is: once it is added to `sys.modules`,
+    # it would also be picked up by other imports. Instead, we derive a unique
+    # module name from the hex-encoded hash of the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/metadata.json b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..155112c59509d3b4d07f4d090cbf57071e3f5217
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # fp16 elements are 2 bytes each.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/binned_gather.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/binned_scatter.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/cumsum.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# The import is wrapped in a try-block to give a clearer error message
+# with instructions for building the C++ operations.
+try:
+    from .._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/gather.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/histogram.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# The import is wrapped in a try-block to give a clearer error message
+# with instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/padded_gather.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/padded_scatter.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/repeat.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/replicate.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
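+
+# Sketch of the intended use, assumed from the kernel signature above rather
+# than documented here: `x` holds one value per bin, `bins` is the inclusive
+# cumsum of tokens per bin, and the output repeats each bin's value across
+# its token range.
+#
+#   x = torch.tensor([[1.0, 2.0]], device="cuda")                  # one value per bin
+#   bins = torch.tensor([3, 5], dtype=torch.int32, device="cuda")  # bin boundaries
+#   replicate(x, bins, 5)   # expected: [[1., 1., 1., 2., 2.]]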
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/round_up.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
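+
+
+# Worked example (illustrative): each entry is rounded up to the next multiple
+# of `value` via truncating division.
+#
+#   x = torch.tensor([3, 128, 130], dtype=torch.int32)
+#   round_up(x, 128)   # tensor([128, 128, 256], dtype=torch.int32)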
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/scatter.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/sort.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
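+
+# Illustrative usage, matching how the op is called elsewhere in this package:
+# `sort` returns the sorted keys together with the permutation that produced
+# them (an argsort), which is used to group token ids by expert.
+#
+#   top_expert = torch.tensor([2, 0, 1, 0], dtype=torch.int32, device="cuda")
+#   bin_ids, indices = sort(top_expert)
+#   # bin_ids -> [0, 0, 1, 2]; indices -> [1, 3, 2, 0] for a stable sort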
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/stk_autocast.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/sum.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/ops/topology.py b/build/torch211-cxx11-cu126-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/autocast.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: tensor has a dimension of length {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
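+# Illustrative expansion performed by the kernel above: `offsets` is the
+# CSR-style row pointer over nonzero blocks and the kernel writes the row id
+# once per block, e.g. offsets = [0, 2, 3] -> out = [0, 0, 1].
+#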
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/matrix.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# TODO: Add heavyweight (data) validation helper.
+# TODO: Add construction helpers.
+# TODO: Make indentation consistent.
+# TODO: Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
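+
+
+# Illustrative BCSR layout, stated here as an example consistent with the
+# validation rules above: a 2x2 matrix with blocking 1 and nonzeros on the
+# diagonal would be stored as
+#   data           -> shape [2, 1, 1]  (one square block per nonzero)
+#   row_indices    -> [0, 1]
+#   column_indices -> [0, 1]
+#   offsets        -> [0, 1, 2]        (CSR-style row pointer over blocks)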
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
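+# Naming note, inferred from the signatures below: each op is named
+# output/lhs/rhs, so dsd computes a Dense output from a Sparse lhs and a
+# Dense rhs, dds a Dense output from Dense x Sparse, and sdd a Sparse output
+# (with the topology of `topo`) from Dense x Dense.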
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/random/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/random/random_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
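+
+
+# Illustrative usage (a sketch, not part of the original module):
+#
+#   m = dense_mask(128, 128, sparsity=0.75, blocking=16)   # dense 0/1 mask, 25% of blocks kept
+#   s = mask(128, 128, sparsity=0.75, blocking=16)         # same style of mask as a block-sparse Matrix
+#   r = randn((4, 128, 128), sparsity=0.5, blocking=16)    # block-sparse standard-normal samples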
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py b/build/torch211-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from .. import random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/xpu_fused_moe.py b/build/torch211-cxx11-cu126-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM wrapper: builds the expert offsets on the device from
+# host-side per-expert token counts.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
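+ # e.g. exclusive_prefix_sum([4, 0, 2]) -> [0, 4, 4, 6]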
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
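+
+# Worked example: with num_tokens=1000 and num_experts_per_node=8, 32 and 64
+# tokens per block give 32*8=256 and 16*8=128 blocks (more than the block
+# size), while 128 gives 8*8=64 <= 128, so 128 is returned.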
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+ # TODO: this will all be integrated into the C++ function. Temporarily exposed here until GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
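+ # Pad each buffer to a 256-byte boundary and record its (size, offset)
+ # so the individual buffers can later be sliced out of one uint8 workspace.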
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
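+
+
+# Illustrative usage (a sketch; requires a CUDA device and the compiled ops):
+#
+#   x = torch.tensor([3, 1, 2, 0], dtype=torch.int32, device="cuda")
+#   cumsum(x, dim=0)                     # tensor([3, 4, 6, 6]) (inclusive)
+#   vals, idx = argsort(x, end_bit=2)    # vals = [0, 1, 2, 3], idx = [3, 1, 2, 0]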
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/activation_fn.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/all_to_all.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
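+
+
+# Illustrative call (a sketch): with two ranks in `group`, a rank holding 5 rows
+# to send as [2, 3] and expecting [2, 4] rows back would use
+#
+#   out, handle = all_to_all(x, output_split_sizes=[2, 4],
+#                            input_split_sizes=[2, 3], group=group)
+#
+# `out` has sum(output_split_sizes) rows; `handle` is a Work object only when
+# async_op=True (otherwise None).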
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/arguments.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in the shared expert (purpose: allow a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+ shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
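+
+
+# Illustrative configuration (a sketch; assumes a CUDA device and, for
+# mlp_impl='grouped', an available grouped-GEMM backend):
+#
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_impl='grouped',
+#       fp16=False,
+#       bf16=True,
+#       device=torch.device('cuda'),
+#   )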
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/common.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/dmlp_registry.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+ (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
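+
+
+# e.g. Arguments(mlp_type='glu', mlp_impl='grouped') selects glu.GroupedGLU,
+# while the defaults ('mlp', 'sparse') select mlp.SparseMLP.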
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/dmoe.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+ # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+ # the matrix multiplications. Calculate the starting
+ # position of each bin.
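+ # e.g. with blocking=128 and tokens_per_expert=[3, 130] this yields
+ # padded_tokens_per_expert=[128, 256] and padded_bins=[128, 384],
+ # while the unpadded bins computed below are [3, 133].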
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/gelu.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
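+ # Derivative of the tanh-approximation GELU, applied in place to the incoming
+ # gradient: g *= d/dx[0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))].
+ # (0.79788456 ~= sqrt(2/pi), 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi).)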
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/glu.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported for GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
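+
+# Hedged usage sketch (hypothetical shapes, not library API documentation):
+#   x:  [tokens, hidden_size]
+#   w1, v1: [num_local_experts, ffn_per_rank, hidden_size]
+#   w2: [num_local_experts, ffn_per_rank, hidden_size]
+#   batch_sizes: CPU int64 tensor of tokens routed to each local expert
+#   y = memory_optimized_grouped_glu(x, w1, v1, w2, batch_sizes, torch.nn.functional.gelu)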
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/memory_test.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6))
+
+    # Calculate weight and gradient memory usage (2 bytes per bf16 element).
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6))
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+    # init_process_group returns None; use the default (WORLD) group explicitly.
+    dist.init_process_group(backend='nccl')
+    group = dist.group.WORLD
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
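+
+# A minimal launch sketch (assumes a single node; the exact module path depends
+# on how this build directory is installed/imported):
+#   torchrun --nproc-per-node=<num_gpus> -m <package>._layers.memory_test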
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/mlp.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
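+
+# Descriptive note: ScaleGradient is the identity in the forward pass and
+# multiplies the incoming gradient by `scale` in the backward pass. The MLPs
+# below use it to divide expert-weight gradients by the expert-parallel world
+# size, e.g. (illustrative): w_scaled = scale_gradient(w, 1 / world_size).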
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
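+
+# Worked example (hypothetical numbers): with 8 experts, ffn_hidden_size=4096
+# and an expert-parallel world size of 4, expert_sharding_degree=4 and
+# hidden_sharding_degree=1, so each rank keeps 2 experts and all 4096 rows:
+#   rank 0 -> experts [0:2], rows [0:4096]; rank 1 -> experts [2:4], ...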
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+        # Apply the activation function.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
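+
+# Note on the memory-reuse scheme above (descriptive, derived from the code):
+# the forward pass saves only x and the pre-activation sdd_out; the activation
+# output is recomputed in the backward pass, its buffer is then reused to hold
+# dactivation_fn_out, and the incoming ddsd_out buffer is reused to hold dx.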
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+    Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class.
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+            # Enable using a weighted sum for the shared expert output,
+            # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/moe.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
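+
+# In symbols (matching the code above): with per-layer mean expert scores s and
+# token counts t concatenated over the layers in this pipeline stage,
+#   loss = (num_experts * loss_weight) / (num_layers * tokens * top_k) * dot(t, s)
+# Hypothetical example: 2 experts, 1 layer, 4 tokens, top_k=1, loss_weight=0.1,
+# t = [3, 1], s = [0.7, 0.3] -> loss = (2*0.1)/(1*4*1) * (3*0.7 + 1*0.3) = 0.12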
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
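+
+    # Hypothetical example: with top_k=2, 1024 tokens, world_size=1, 8 experts
+    # and moe_capacity_factor=1.0, expert_capacity = int(1.0 * 2*1024/8) = 256.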
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
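+
+    # Toy example (illustrative only): top_expert = [1, 0, 1, 1] with 2 experts
+    # sorts to bin_ids = [0, 1, 1, 1], indices = [1, 0, 2, 3],
+    # tokens_per_expert = [1, 3] and bins = inclusive_cumsum -> [1, 4].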
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+        tokens_per_expert: torch.Tensor,  # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
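+
+    # Illustrative example (hypothetical counts): with 2 expert-parallel ranks,
+    # rank 0 might have send_counts=[5, 3] (5 rows stay for its own experts,
+    # 3 rows go to rank 1) and recv_counts=[5, 6] (6 rows arrive from rank 1);
+    # the second all_to_all swaps these counts to route results back before the
+    # final local scatter restores the original token order.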
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/mpu.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+        super().__init__()
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
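+
+
+# Worked example (hypothetical numbers): world_size=8, moe_num_experts=4 and
+# ffn_hidden_size=8192 -> expert_sharding_degree=4, hidden_sharding_degree=2,
+# experts_per_rank=1 and features_per_rank=4096.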
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/router.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
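+
+    # Shape sketch (illustrative): for x of shape [sl, bs, hs] the router
+    # returns scores [sl*bs, num_experts], expert_weights [sl*bs, top_k] and
+    # expert_indices [sl*bs, top_k].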
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_layers/sharedexpert_registry.py b/build/torch211-cxx11-cu128-aarch64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch211-cxx11-cu128-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..fea7e4a177f87345182ed40bc6ffd6dd007b6ca5
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91f8b706af9af0569af55e732d0f508af6a31c3bff9268dc3cbd24193c5fee0c
+size 21088232
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_version.py b/build/torch211-cxx11-cu128-aarch64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/backend/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/backend/kernels.py b/build/torch211-cxx11-cu128-aarch64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton autotune when testing in an environment without CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
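+
+# Toy example (illustrative only): with bins=[1, 3] and block-padded
+# padded_bins=[128, 256], expert 0's single token lands at row 0, the two
+# tokens for expert 1 land at rows 128 and 129 of the zero-initialized output,
+# and the remaining rows in each region stay as padding.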
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/benchmark_util.py b/build/torch211-cxx11-cu128-aarch64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
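+
+
+# Illustrative usage sketch (not part of the original module): time a single
+# CUDA matmul with benchmark_function and report it via log_benchmark. The
+# shapes and the benchmark name are arbitrary and assume a CUDA device.
+if __name__ == "__main__":
+    a = torch.randn(1024, 1024, device="cuda")
+    b = torch.randn(1024, 1024, device="cuda")
+    mean_ms, std_ms = benchmark_function(lambda: torch.matmul(a, b))
+    log_benchmark("MatMul", {"m": 1024, "n": 1024, "k": 1024}, mean_ms, std_ms)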
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/cpu_fused_moe.py b/build/torch211-cxx11-cu128-aarch64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
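+
+
+# Illustrative sketch (not part of the original module): how the clamped
+# SwigluOAI activation compares with the plain SiLU gate on toy values. The
+# helper name and numbers are made up for the example.
+def _activation_demo() -> None:
+    gate = torch.tensor([[0.5, 8.0]])
+    up = torch.tensor([[1.0, 10.0]])
+    # Out-of-range entries are clamped to +/-limit before gating, so the
+    # second column saturates near (limit + 1) * limit ~= 56.
+    print(swigluoai_activation(gate, up))
+    # The plain SwiGLU variant applies no clamping.
+    print(silu_and_mul_activation(gate, up))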
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
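+
+
+# Illustrative sketch (not part of the original module): routing a handful of
+# tokens through a random router. The helper name and all sizes are made up.
+def _routing_demo() -> None:
+    tokens, hidden, num_experts, top_k = 4, 16, 8, 2
+    x = torch.randn(tokens, hidden)
+    router_weight = torch.randn(num_experts, hidden)
+    logits, weights, indices = route_tokens_cpu(x, router_weight, None, top_k, num_experts)
+    # logits: [4, 8]; weights and indices: [4, 2]; weights sum to 1 per token.
+    print(logits.shape, weights.shape, indices.shape, weights.sum(dim=-1))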
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over the experts and processes each expert's
+    assigned tokens with batched PyTorch operations, which is reasonably
+    efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # For each expert, gather its assigned tokens, run the expert MLP on them,
+    # and accumulate the weighted results back into the output.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
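+
+
+# Illustrative sketch (not part of the original module): a tiny end-to-end call
+# into cpu_fused_moe with random weights in the non-interleaved layout. The
+# helper name and shapes are made up but follow the docstring above.
+def _fused_moe_demo() -> None:
+    tokens, hidden, inter, num_experts, top_k = 4, 16, 32, 8, 2
+    x = torch.randn(tokens, hidden)
+    w1 = torch.randn(num_experts, hidden, 2 * inter) * 0.02
+    w2 = torch.randn(num_experts, inter, hidden) * 0.02
+    router_weight = torch.randn(num_experts, hidden)
+    _, weights, ids = route_tokens_cpu(x, router_weight, None, top_k, num_experts)
+    out = cpu_fused_moe(x, w1, w2, weights, ids, activation="silu", is_interleaved=False)
+    print(out.shape)  # torch.Size([4, 16])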
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/cpu_moe_cpp.py b/build/torch211-cxx11-cu128-aarch64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
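+
+
+# Illustrative sketch (not part of the original module): calling fused_moe_cpp
+# directly with bf16 tensors. This needs the compiled C++ extension behind
+# ops.fused_experts; the helper name and all sizes are made up.
+def _fused_moe_cpp_demo() -> None:
+    tokens, hidden, inter, num_experts, top_k = 4, 64, 128, 8, 2
+    x = torch.randn(tokens, hidden, dtype=torch.bfloat16)
+    w1 = torch.randn(num_experts, 2 * inter, hidden, dtype=torch.bfloat16)  # [E, 2N, K]
+    w2 = torch.randn(num_experts, hidden, inter, dtype=torch.bfloat16)      # [E, K, N]
+    router_weight = torch.randn(num_experts, hidden, dtype=torch.bfloat16)
+    _, weights, ids = route_tokens_cpu(x, router_weight, None, top_k, num_experts)
+    out = fused_moe_cpp(x, w1, w2, weights, ids.to(torch.int32))
+    print(out.shape)  # expected: [tokens, hidden]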
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
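+
+
+# Illustrative sketch (not part of the original module): a grouped GEMM over
+# two expert groups. Requires the compiled megablocks backend and a CUDA
+# device; the sizes below are arbitrary.
+if __name__ == "__main__":
+    a = torch.randn(6, 16, device="cuda", dtype=torch.bfloat16)
+    b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16)
+    batch_sizes = torch.tensor([2, 4])  # rows 0-1 use b[0], rows 2-5 use b[1]
+    out = gmm(a, b, batch_sizes)
+    print(out.shape)  # torch.Size([6, 32])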
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm/ops.py b/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
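+
+
+# Illustrative sketch (not part of the original module): the autograd wrapper
+# in action. Requires the compiled backend and a CUDA device; sizes are
+# arbitrary.
+if __name__ == "__main__":
+    a = torch.randn(6, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+    b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+    batch_sizes = torch.tensor([2, 4])
+    out = gmm(a, b, batch_sizes)
+    out.sum().backward()  # gradients flow to 'a' and 'b' via GroupedGemm.backward
+    print(a.grad.shape, b.grad.shape)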
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm_util.py b/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored in this repository (see ./grouped_gemm), so no
+    # external import is required here.
+    _grouped_gemm_is_available = True
+except ImportError:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/layers.py b/build/torch211-cxx11-cu128-aarch64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+                # Meta implementation - match the real output shape
+                # (num_experts, expert_capacity, hidden_size).
+                if x.dim() >= 2:
+                    return torch.empty(
+                        (bins.shape[0], bin_size, x.size(-1)),
+                        dtype=x.dtype,
+                        device=x.device,
+                    )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+                # Meta implementation - the scatter reduces back to
+                # (tokens, hidden_size).
+                if x.dim() >= 3:
+                    tokens = indices.size(0) // top_k if top_k > 0 else x.size(1)
+                    return torch.empty(
+                        (tokens, x.size(-1)), dtype=x.dtype, device=x.device
+                    )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
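+
+
+# Illustrative sketch (not part of the original module): how the sharding
+# helpers split 8 experts with ffn_hidden_size=3072 across 16 ranks. The
+# helper name and numbers are made up but satisfy the divisibility checks.
+def _sharding_demo() -> None:
+    world_size, num_experts, ffn_hidden_size = 16, 8, 3072
+    print(expert_sharding_degree(world_size, num_experts))                   # 8
+    print(hidden_sharding_degree(world_size, num_experts, ffn_hidden_size))  # 16 // 8 = 2
+    print(experts_per_rank(num_experts, world_size))                         # 8 // 8 = 1
+    print(features_per_rank(ffn_hidden_size, world_size, num_experts))       # 3072 // 2 = 1536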
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} tokens_per_expert "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
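+
+
+# Illustrative sketch (not part of the original module): the capacity formula
+# for the non-parallel case. With 1024 tokens, top_k=4, 128 experts and a
+# capacity factor of 1.0, each expert is sized for 4 * 1024 / 128 = 32 slots.
+# The helper name and numbers are made up for the example.
+def _expert_capacity_demo() -> None:
+    cap = expert_capacity(
+        tokens=1024,
+        top_k=4,
+        num_experts=128,
+        expert_parallel_group=None,
+        moe_capacity_factor=1.0,
+        moe_expert_model_parallelism=False,
+    )
+    print(cap)  # 32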
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (count_num_experts,) = tokens_per_expert.size()
+    assert count_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
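+
+
+# Illustrative sketch (not part of the original module): what indices_and_bins
+# computes, reproduced with plain torch ops on CPU for a tiny routing example.
+# The real function relies on the fused CUDA sort/histogram/cumsum kernels.
+def _indices_and_bins_demo() -> None:
+    num_experts = 3
+    top_expert = torch.tensor([1, 0, 1, 2])
+    bin_ids, indices = torch.sort(top_expert, stable=True)    # [0, 1, 1, 2], [1, 0, 2, 3]
+    tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)  # [1, 2, 1]
+    bins = torch.cumsum(tokens_per_expert, dim=0)             # [1, 3, 4] (end offsets per expert)
+    print(indices, bin_ids, bins, tokens_per_expert)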
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: Optional[torch.Tensor] = None,
+    w2: Optional[torch.Tensor] = None,
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
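+
+
+# Minimal sketch of the router_scores construction used above, with made-up
+# numbers: the top-k weights are scattered back into a dense
+# (num_tokens, num_experts) grid and then transposed to (num_experts, num_tokens).
+def _router_scores_example():
+    logits = torch.zeros(2, 4)  # 2 tokens, 4 experts
+    expert_indices = torch.tensor([[0, 2], [1, 3]])
+    expert_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4]])
+    scores = (
+        torch.zeros_like(logits)
+        .scatter_(1, expert_indices, expert_weights)
+        .transpose(0, 1)
+    )
+    return scores  # shape (4, 2)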
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: Optional[torch.Tensor] = None,
+    w2: Optional[torch.Tensor] = None,
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
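+
+
+# Usage sketch with assumed sizes and standard torch initializers (the concrete
+# values are hypothetical and only meant to show the calling convention):
+def _create_shared_expert_weights_example():
+    up_w, down_w, up_b, down_b = create_shared_expert_weights(
+        hidden_size=1024,
+        shared_expert_hidden_size=4096,
+        device=torch.device("cpu"),
+        dtype=torch.float32,
+        init_method=torch.nn.init.xavier_uniform_,
+    )
+    # up_w: (4096, 1024), down_w: (1024, 4096); biases are None by default.
+    return up_w, down_w, up_b, down_b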
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
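+
+
+# Tiny illustration of the closure introspection get_device_mesh relies on.
+# The hook below is hypothetical; the real pre-hook is registered by the
+# transformers integration code.
+def _closure_introspection_example():
+    device_mesh = "fake-mesh"
+
+    def hook(module, args):
+        return device_mesh  # captures `device_mesh` as a free variable
+
+    idx = hook.__code__.co_freevars.index("device_mesh")
+    return hook.__closure__[idx].cell_contents  # -> "fake-mesh"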
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
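+
+
+# Sketch of wiring up the shared expert (hypothetical sizes; the router and
+# experts are assumed to be configured the same way as for MegaBlocksMoeMLP):
+def _shared_expert_setup_example():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+    up_w, down_w, _, _ = create_shared_expert_weights(
+        hidden_size=1024,
+        shared_expert_hidden_size=4096,
+        device=torch.device("cpu"),
+        dtype=torch.float32,
+        init_method=torch.nn.init.xavier_uniform_,
+    )
+    mlp.set_shared_expert_weights(up_w, down_w, weighted_sum=True)
+    return mlp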
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+    from .xpu_fused_moe import MegaBlocksMoeMLP
+
+try:
+    from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
+except ImportError:
+    # The CPU extension is optional; CUDA-only builds may not ship it.
+    CPUMegaBlocksMoeMLP = None
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/megablocks/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is: once it is added to `sys.modules`,
+    # it would also be picked up by other imports. Instead, we derive a unique
+    # module name from the hex-encoded hash of the file path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
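+# Re-export everything from the top-level package __init__.py so that this
+# build-specific directory behaves like the canonical `megablocks` package.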
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/metadata.json b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..e3e4edf582b7ffb515d0ed32e9fc9c89f125c441
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/all_to_all_benchmark.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per fp16 element.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/binned_gather.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/binned_scatter.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/cumsum.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
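+
+
+# Assuming the CUDA kernels follow the usual cumsum conventions, for
+# x = [2, 3, 1] (int32, on GPU):
+#   exclusive_cumsum(x, 0) -> [0, 2, 5]   (start offsets)
+#   inclusive_cumsum(x, 0) -> [2, 5, 6]   (end offsets, used as "bins" elsewhere)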
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/gather.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/histogram.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1), which calls
+# torch.as_strided(...). Circumvent this chain to avoid the overhead
+# it adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/padded_gather.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/padded_scatter.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/repeat.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
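+
+
+# For example, repeat(x, (2, 1)) stacks two copies of x along dim 0, while an
+# all-ones tiling such as repeat(x, (1, 1)) returns x itself without copying.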
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/replicate.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/round_up.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
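+
+
+# Worked example: round_up(torch.tensor([1, 128, 129], dtype=torch.int32), 128)
+# -> [128, 128, 256], i.e. each entry is rounded up to the next multiple of 128.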
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/scatter.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/sort.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
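+
+
+# Example (hypothetical values): sorting expert assignments with only as many
+# radix bits as needed, e.g. 128 experts -> end_bit = 7:
+#   x = torch.tensor([3, 0, 2, 0], dtype=torch.int32, device="cuda")
+#   sorted_x, order = sort(x, 7)  # sorted_x = [0, 0, 2, 3]; order = argsort indices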
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/stk_autocast.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/sum.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
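+
+
+# A minimal sketch of the intent (hypothetical shapes): for router weights of
+# shape [tokens, top_k], sum(w, dim=1) reduces over top_k; when top_k == 1 the
+# squeeze branch returns a view instead of launching a reduction kernel.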
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/ops/topology.py b/build/torch211-cxx11-cu128-aarch64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
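+
+
+# Usage sketch (hypothetical values; padded_bins is the padded, cumulative
+# token count per expert on a CUDA device):
+#   padded_bins = torch.tensor([128, 256, 512], dtype=torch.int32, device="cuda")
+#   out = topology(padded_bins, 128, 4, 12)  # int16 tensor with 4 * 12 entries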
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/autocast.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+        # Recurse into dicts so nested tensors are cast as well.
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
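+
+
+# Usage sketch (hypothetical Function; the decorators above are the real API):
+#
+#   class Scale(torch.autograd.Function):
+#       @staticmethod
+#       @custom_fwd
+#       def forward(ctx, x):
+#           return x * 2
+#
+#       @staticmethod
+#       @custom_bwd
+#       def backward(ctx, dy):
+#           return dy * 2
+#
+# With torch.autocast active, forward receives its tensor arguments cast to the
+# autocast dtype; backward always runs with autocast disabled.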
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/sputnik.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
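+
+# For example, torch.randn(4, 8).t() has shape (8, 4) and stride (1, 8), so it
+# is reported as transposed; a freshly allocated (8, 4) tensor is not.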
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/triton_kernels.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
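+
+# With the default TritonConfig this requires M and N to be multiples of 128
+# and K a multiple of 32; e.g. M=256, K=48, N=384 fails the K check.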
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
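+
+# Example (sketch): offsets = [0, 2, 3] describes two nonzero blocks in block
+# row 0 and one in block row 1, so the kernel writes out = [0, 0, 1].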
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/matrix.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
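+
+
+# Usage sketch: mul(a, stk.ops.ones_like(a)) returns a Matrix with the same
+# values and topology as a, since only the data tensors are multiplied.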
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/linear_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
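+
+
+# Round-trip sketch: to_dense(to_sparse(x, blocking=128)) reproduces a float16
+# x exactly whenever no 128x128 block of x is entirely zero, e.g.
+#   x = torch.randn(256, 256).half()
+#   torch.equal(to_dense(to_sparse(x, blocking=128)), x)  # True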
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/random/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/random/random_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
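+
+# Example: dense_mask(8, 16, 0.5, blocking=8) keeps round(2 * 0.5) = 1 of the
+# two 8x8 blocks, chosen at random, and zeroes the other.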
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py b/build/torch211-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from .. import random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/xpu_fused_moe.py b/build/torch211-cxx11-cu128-aarch64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# default
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
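+
+    # e.g. exclusive_prefix_sum([3, 1, 2]) == [0, 3, 4, 6]; the last entry is
+    # the total number of routed tokens.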
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
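+    # Pick the smallest candidate block size such that
+    # ceil(num_tokens / block) * num_experts_per_node <= block; otherwise fall back to 1024.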
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+    # Convert packed u4 weights to s4 so the GEMM kernel does not need a zero point.
+    # Only the default zero point (8) is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ function. Temporarily exposed here until the GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
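+    # Slice typed views of the prologue outputs (per-expert offsets and the
+    # permutation maps) out of the shared uint8 workspace.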
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
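+    # Grouped GEMM over the permuted tokens: [tokens, hidden_size] x w13 -> [tokens, 2*inter_size],
+    # i.e. the fused gate/up projection. The scaled branch handles fp8/int4/mxfp4 weights.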
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
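+    # Grouped GEMM down-projection: [tokens, inter_size] x w2 -> [tokens, hidden_size].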
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
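+    # Un-permute the expert outputs back into token order and combine them with
+    # the top-k routing weights.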
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
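+    # Softmax is taken over the selected top-k logits only, not the full expert set.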
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
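+
+# A minimal usage sketch (device and dtype are illustrative assumptions):
+#   x = torch.arange(4, device="cuda", dtype=torch.int32)  # [0, 1, 2, 3]
+#   cumsum(x)                  # inclusive -> [0, 1, 3, 6]
+#   cumsum(x, exclusive=True)  # exclusive -> [0, 0, 1, 3]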
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/activation_fn.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/all_to_all.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/arguments.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (allows a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+    shared_expert_hidden_size: Optional[
+        int] = None  # hidden size of the shared expert, if it should differ from ffn_hidden_size (the default)
+    shared_expert_weighted_sum: bool = False  # enable using a weighted sum for the shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+                from packaging import version as _version
+                if _version.parse(triton.__version__) >= _version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/common.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/dmoe.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
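+# Promote a 0-d tensor to 1-d so it can be concatenated and indexed downstream.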
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+ # A blocks offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/gelu.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/glu.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Apply the activation function and GLU gating.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+        # Compute the GLU.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/memory_test.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/mlp.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+        # Activation function.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+            # Use a weighted sum for the shared expert output,
+            # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/moe.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
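+
+
+# Illustrative numeric check (assumed values, not part of the original API):
+# the scale above for moe_loss_weight = 0.1, 8 experts, 2 layers, 1024 tokens
+# and top_k = 2 is (8 * 0.1) / (2 * 1024 * 2), roughly 1.95e-4.
+def _example_lbl_scale():
+    moe_num_experts, moe_loss_weight = 8, 0.1
+    num_layers, tokens, moe_top_k = 2, 1024, 2
+    return (moe_num_experts * moe_loss_weight) / (num_layers * tokens * moe_top_k)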
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
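+
+    # Worked example (illustrative, assumed values): with top_k = 2, 4096
+    # tokens, world_size = 1, 8 experts and moe_capacity_factor = 1.25,
+    # tokens_per_expert = 2 * 4096 / 8 = 1024 and the returned capacity is
+    # int(1.25 * 1024) = 1280 tokens per expert.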
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignments. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+            # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
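+
+
+# Illustrative usage sketch (hypothetical Arguments construction; the dataclass
+# accepts further fields such as device, dtype and parallelism settings that
+# are left at their defaults here).
+def _example_moe_forward():
+    args = Arguments(
+        hidden_size=128,
+        ffn_hidden_size=512,
+        moe_num_experts=8,
+        moe_top_k=2,
+    )
+    layer = MoE(args)
+    x = torch.randn(16, 4, 128, device=args.device)  # [sl, bs, hs]
+    return layer(x)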
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/mpu.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
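+
+
+# Illustrative sketch (assumed sizes, not part of the original API): with 4
+# expert-parallel ranks and 2 experts, expert_sharding_degree = min(4, 2) = 2
+# and hidden_sharding_degree = 4 // 2 = 2, so each rank owns 1 expert and
+# half of the ffn_hidden_size features.
+def _example_sharding_degrees():
+    world_size, moe_num_experts, ffn_hidden_size = 4, 2, 1024
+    esd = min(world_size, moe_num_experts)
+    hsd = world_size // esd
+    return moe_num_experts // esd, ffn_hidden_size // hsd  # (1, 512)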
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/router.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
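+
+
+# Illustrative sketch (assumed sizes, not part of the original API): plain
+# top-k routing mirroring the softmax -> topk -> optional re-normalization
+# performed by LearnedRouter.forward above (re-normalization shown with p = 1).
+def _example_topk_routing(tokens: int = 4, num_experts: int = 8, top_k: int = 2):
+    logits = torch.randn(tokens, num_experts)
+    scores = logits.softmax(dim=-1)
+    expert_weights, expert_indices = torch.topk(scores, top_k, dim=-1)
+    expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)
+    return scores, expert_weights, expert_indices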
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch211-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+    """Returns a SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..767c0720dd56d80d3e809cbe78db66791cefbc43
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:adea81c43411e3594ad56b695b6913a6b03ccffa516d582f4cf6a6dba57bab04
+size 21009984
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
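+
+
+# Illustrative usage (the op name 'sort' is assumed for demonstration):
+#   add_op_namespace_prefix("sort") -> "_megablocks_cuda_7a6bcf4::sort"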
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_version.py b/build/torch211-cxx11-cu128-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/backend/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/backend/kernels.py b/build/torch211-cxx11-cu128-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton autotune when testing in an environment that does not have
+# CUDA. This preserves the original code while enabling testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
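+
+
+# Shape example (illustrative, assumed sizes): with 4 tokens, top_k = 2 and
+# 3 experts, 'indices' and 'bin_ids' each have 8 entries. If the per-expert
+# token counts are [3, 2, 3], then bins = [3, 5, 8]; padding each expert's
+# count up to a 64-row block gives padded_bins = [64, 128, 192], so 'out'
+# has padded_bins[-1] = 192 rows.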
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+    # NOTE: There is no padding, so the number of output rows equals the
+    # number of input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/benchmark_util.py b/build/torch211-cxx11-cu128-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
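+
+
+# Illustrative usage sketch (assumes a CUDA device, since benchmark_function
+# times with torch.cuda.Event): benchmark a large matmul and print the results.
+if __name__ == "__main__" and torch.cuda.is_available():
+    a = torch.randn(4096, 4096, device='cuda')
+    b = torch.randn(4096, 4096, device='cuda')
+    mean_ms, std_ms = benchmark_function(lambda: torch.mm(a, b))
+    log_benchmark('MatMul', {'m': 4096, 'n': 4096, 'k': 4096}, mean_ms, std_ms)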
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/cpu_fused_moe.py b/build/torch211-cxx11-cu128-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
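+
+
+# Illustrative sketch of the routing helper above (arbitrary sizes):
+#
+#     hidden, num_experts, top_k = 64, 8, 2
+#     x = torch.randn(4, hidden)
+#     router_w = torch.randn(num_experts, hidden)
+#     logits, weights, ids = route_tokens_cpu(x, router_w, None, top_k, num_experts)
+#     # logits: [4, 8]; weights and ids: [4, 2], weights softmaxed over the top_k dim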
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation processes all experts in parallel using batched operations
+ instead of sequential for loops, which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Build expert mask: which tokens go to which expert
+ # expert_mask[expert_id] contains indices of (token_idx, topk_pos) pairs
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
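+
+
+# Illustrative sketch of a tiny end-to-end call (arbitrary sizes; the standard,
+# non-interleaved gate/up layout is used here):
+#
+#     tokens, hidden, inter, num_experts, top_k = 4, 32, 64, 8, 2
+#     x = torch.randn(tokens, hidden)
+#     w1 = torch.randn(num_experts, hidden, 2 * inter) * 0.02
+#     w2 = torch.randn(num_experts, inter, hidden) * 0.02
+#     _, weights, ids = route_tokens_cpu(x, torch.randn(num_experts, hidden), None, top_k, num_experts)
+#     out = cpu_fused_moe(x, w1, w2, weights, ids, activation="silu", is_interleaved=False)
+#     # out: [4, 32]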
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/cpu_moe_cpp.py b/build/torch211-cxx11-cu128-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
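+
+
+# Illustrative sketch of calling the C++ path directly (requires the compiled
+# _ops extension; shapes follow the docstring above, i.e. w1 is [E, 2N, K] and
+# w2 is [E, K, N], transposed relative to the pure-PyTorch cpu_fused_moe layout):
+#
+#     x = torch.randn(4, 32, dtype=torch.bfloat16)          # [M, K]
+#     w1 = torch.randn(8, 128, 32, dtype=torch.bfloat16)    # [E, 2N, K], N = 64
+#     w2 = torch.randn(8, 32, 64, dtype=torch.bfloat16)     # [E, K, N]
+#     weights = torch.rand(4, 2)                            # [M, topk]
+#     ids = torch.randint(0, 8, (4, 2), dtype=torch.int32)  # [M, topk]
+#     out = fused_moe_cpp(x, w1, w2, weights, ids)           # [M, K]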
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py b/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
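+
+
+# Illustrative sketch: a grouped GEMM over two groups of rows. batch_sizes is a
+# 1-D CPU tensor of per-group row counts summing to a.shape[0]; with
+# trans_b=False, b holds one [k, n] matrix per group. Dtype/device support
+# follows the underlying backend (bfloat16 on CUDA is the typical case).
+#
+#     a = torch.randn(10, 16, device='cuda', dtype=torch.bfloat16)
+#     b = torch.randn(2, 16, 32, device='cuda', dtype=torch.bfloat16)
+#     batch_sizes = torch.tensor([6, 4])
+#     c = gmm(a, b, batch_sizes)  # [10, 32]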
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm_util.py b/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored into this package, so no external import is required.
+    _grouped_gemm_is_available = True
+except ImportError:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/layers.py b/build/torch211-cxx11-cu128-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
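+# For example, with moe_jitter_eps=0.1 each element of x is scaled by a factor
+# drawn uniformly from [0.9, 1.1).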
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
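+# Note: gate/up are read from the interleaved [g0, u0, g1, u1, ...] layout of the
+# gate_up projection, and the activation is the clamped swiglu variant
+# (gate clamped above by `limit`, up clamped to [-limit, limit]).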
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
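+# For example, with moe_top_k=3 and shared_expert_weighted_sum=True the shared
+# expert contributes 1/4 of the combined output and the routed experts 3/4.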
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
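+# For example, 1024 tokens with top_k=4, 128 experts, capacity_factor=1.0 and no
+# expert model parallelism gives int(1.0 * 4 * 1024 * 1 / 128) = 32 slots per expert.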
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (count_num_experts,) = tokens_per_expert.size()
+    assert count_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
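+# For example, top_expert = [2, 0, 2, 1] with num_experts=3 yields
+# bin_ids = [0, 1, 2, 2], indices = [1, 3, 0, 2] (token order sorted by expert),
+# tokens_per_expert = [1, 1, 2] and bins = [1, 2, 4] (inclusive cumsum).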
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
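+
+
+# Illustrative sketch (arbitrary sizes; any in-place initializer works):
+#
+#     up_w, down_w, _, _ = create_shared_expert_weights(
+#         hidden_size=1152,
+#         shared_expert_hidden_size=3072,
+#         device=torch.device("cpu"),
+#         dtype=torch.float32,
+#         init_method=torch.nn.init.xavier_uniform_,
+#     )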
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
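+        """Attach the shared-expert weights used alongside the routed experts.
+
+        Sketch of expected usage (names and shapes are illustrative assumptions,
+        not checked here):
+
+            mlp.set_shared_expert_weights(
+                up_proj_weight=torch.randn(hidden_size, shared_ffn_size, device="cuda"),
+                down_proj_weight=torch.randn(shared_ffn_size, hidden_size, device="cuda"),
+                activation_fn=torch.nn.functional.silu,
+            )
+        """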
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/megablocks/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is: after adding it to `sys.modules`,
+    # it would also be used for other imports. Instead, derive a unique module
+    # name from the hex-encoded hash of the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..e3e4edf582b7ffb515d0ed32e9fc9c89f125c441
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per element (fp16).
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+    assert dist.is_available()
+    # init_process_group returns None; benchmark against the default (WORLD) group.
+    dist.init_process_group(backend='nccl')
+    group = dist.group.WORLD
+    local_rank = dist.get_rank(group)
+    torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/binned_gather.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/binned_scatter.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/cumsum.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
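+
+# Example (illustrative): for x = [1, 2, 3] along dim 0,
+# exclusive_cumsum(x, 0) -> [0, 1, 3] and inclusive_cumsum(x, 0) -> [1, 3, 6].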
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/gather.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
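+
+# Note: gather and scatter permute tokens in opposite directions, so the backward
+# of GatherOp is implemented with kernels.scatter (and ScatterOp's backward with
+# kernels.gather).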
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/histogram.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
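+
+# Illustrative semantics (assuming integer bincount behaviour of the kernel):
+# histogram(torch.tensor([0, 1, 1, 3], device="cuda"), 4) -> counts [1, 2, 0, 1].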
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1), which calls
+# torch.as_strided(...). Circumvent this chain to avoid the overhead it adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/padded_gather.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/padded_scatter.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/repeat.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
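+
+
+# Note: when every tiling factor is 1 the input is returned unchanged, avoiding
+# the copy that x.repeat(...) would otherwise make.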
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/replicate.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/round_up.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of fewer than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
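+
+
+# Worked example (illustrative):
+#   round_up(torch.tensor([1, 130, 256], dtype=torch.int32), 128)
+#   -> tensor([128, 256, 256], dtype=torch.int32)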
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/scatter.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/sort.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
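+
+# Callers with a known small key range can pass a smaller end_bit so the radix
+# sort only touches the needed bits; e.g. expert ids in [0, 128) need only
+# ceil(log2(128)) = 7 bits (illustrative).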
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/stk_autocast.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/sum.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
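+
+
+# Note: when the reduced dimension has size 1 the sum is a no-op, so the cheaper
+# squeeze is used instead of launching a reduction kernel.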
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/ops/topology.py b/build/torch211-cxx11-cu128-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/autocast.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
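+
+
+# Example: torch.randn(3, 4).t() has size (4, 3) and strides (1, 4), so it is
+# recognized here as a transposed view of contiguous storage.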
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: tensor dimension of length {} must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
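+    # SDD: sampled dense-dense matmul. Computes lhs @ rhs only at the nonzero
+    # block positions given by (row_indices, column_indices); each program
+    # instance produces one BLOCK_M x BLOCK_N block of the sparse output.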
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store the accumulator block into the sparse output matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
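+    # DSD: dense = sparse @ dense. A is block-sparse (BCSR data plus metadata),
+    # B is dense. Each program computes one BLOCK_M x BLOCK_N tile of the dense
+    # output by iterating over the nonzero blocks in A's block row.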
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
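+    # DDS: dense = dense @ sparse. A is dense, B is block-sparse (BCSR).
+    # Each program computes one BLOCK_M x BLOCK_N tile of the dense output
+    # by iterating over the nonzero blocks in B's block column.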
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
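+
+# Illustrative example: for block-row offsets [0, 2, 3, 5] the kernel writes
+# out = [0, 0, 1, 2, 2], i.e. one block-row id per nonzero block.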
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/matrix.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# TODO:
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers.
+# 3. Make indentation consistent.
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
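+
+# Illustrative example for _transpose: with nonzero blocks at (0, 1) and
+# (1, 0) in a 2x2 block grid, row_indices = [0, 1] and column_indices = [1, 0].
+# Sorting by column gives gather_indices = [1, 0], so column_indices_t = [1, 0],
+# block_offsets_t = [1, 0] and offsets_t = [0, 1, 2] (one block per column).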
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
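+
+    # Illustrative layout: a (256, 256) matrix with blocking 128 and nonzero
+    # blocks at block positions (0, 1) and (1, 0) is stored as
+    # data.shape == (2, 128, 128), row_indices == [0, 1],
+    # column_indices == [1, 0], offsets == [0, 1, 2] (cumulative count of
+    # nonzero blocks per block row).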
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
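+
+
+# Example usage (illustrative; assumes the stk.random/stk.ops helpers in this
+# package and 128x128 blocking):
+#
+#   x = stk.random.randn((256, 256), sparsity=0.5, blocking=128)
+#   y = stk.ops.ones_like(x)
+#   z = stk.ops.mul(x, y)   # same nonzero pattern as x, entries equal to x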
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
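+
+
+# Example usage (illustrative; assumes a CUDA device, float16 data, and the
+# 128x128 blocking expected by the default Triton config):
+#
+#   x = stk.random.randn((512, 512), sparsity=0.5, blocking=128).cuda()
+#   w = torch.randn(512, 512, dtype=torch.float16, device="cuda")
+#   y = stk.ops.dsd(x, w)        # dense = sparse @ dense
+#   s = stk.ops.sdd(w, w, x)     # sparse output with x's topology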
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler than the current implementation makes it look.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
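+
+# Illustrative round trip: for a 2D tensor x whose dimensions are divisible by
+# `blocking`, to_dense(to_sparse(x, blocking)) reproduces x exactly; all-zero
+# blocks are dropped by to_sparse and restored as zeros by to_dense.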
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/random/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/random/random_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
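+
+
+# Example usage (illustrative): dense_mask(256, 512, 0.75, blocking=128)
+# returns a (256, 512) float32 0/1 tensor in which roughly 75% of the 128x128
+# blocks are zeroed; randn((256, 512), 0.75, 128) builds the corresponding
+# sparse stk.Matrix filled with normally distributed values.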
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py b/build/torch211-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/xpu_fused_moe.py b/build/torch211-cxx11-cu128-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM wrapper; see cutlass_grouped_gemm_xe2 below for the Xe2-specific variant.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
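+    # e.g. exclusive_prefix_sum([2, 3, 1]) == [0, 2, 5, 6]; entry i is the
+    # first row owned by expert i and the last entry is the total row count.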
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
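+
+# Illustrative example: with num_tokens=1000 and num_experts_per_node=8 the
+# first candidate satisfying ceilDiv(1000, c) * 8 <= c is c=128
+# (8 blocks * 8 experts = 64 <= 128), so 128 is returned.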
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+    # Convert u4 to s4 so the GEMM kernel does not need to apply a zero point.
+    # Only the default zero point (8) is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
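+
+# Note (illustrative): for values in [-8, 7] the (sign << 3) | low-3-bits
+# packing above is exactly the 4-bit two's-complement encoding (v & 0xF),
+# e.g. -8 -> 0b1000, -1 -> 0b1111, 7 -> 0b0111.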
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+    # 4-bit weight formats use an [E, N, K] layout;
+    # other dtypes use [E, K, N].
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ func. Temporarily exposed here until GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+    # Weight scales are only passed for quantized (fp8 / int4 / mxfp4) weights.
+    gemm1_scales = w13_scales if (is_fp8 or is_int4 or is_mxfp4) else None
+    ops.cutlass_grouped_gemm_interface(
+        ptr_A=gemm1_input,
+        ptr_B=input_B,
+        ptr_scales=gemm1_scales,
+        ptr_bias=w13_bias,
+        ptr_D=gemm1_output,
+        expert_first_token_offset=expert_first_token_offset,
+        N=2 * inter_size,
+        K=hidden_size,
+        num_experts=num_experts_per_node,
+        is_B_int4=is_int4,
+        is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+    gemm2_scales = w2_scales if (is_fp8 or is_int4 or is_mxfp4) else None
+    ops.cutlass_grouped_gemm_interface(
+        ptr_A=input_A,
+        ptr_B=input_B,
+        ptr_scales=gemm2_scales,
+        ptr_bias=w2_bias,
+        ptr_D=gemm2_output,
+        expert_first_token_offset=expert_first_token_offset,
+        N=hidden_size,
+        K=inter_size,
+        num_experts=num_experts_per_node,
+        is_B_int4=is_int4,
+        is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
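+# Usage sketch for apply_jitter (illustrative values, not from the original code):
+# with moe_jitter_eps=0.01 each element is scaled by a factor drawn uniformly
+# from [0.99, 1.01).
+#
+#   x = torch.randn(4, 8)
+#   y = apply_jitter(x, moe_jitter_eps=0.01)
+#   assert ((y - x).abs() <= 0.01 * x.abs() + 1e-6).all()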
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
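+# compute_top_k returns a (values, indices) pair in both branches, since
+# Tensor.max(dim=-1, keepdim=True) and torch.topk both return named tuples.
+# Illustrative sketch:
+#
+#   logits = torch.randn(6, 16)            # [tokens, num_experts]
+#   vals, idx = compute_top_k(logits, 2)   # both have shape [6, 2]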
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
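+# Usage sketch for route_tokens_xpu (illustrative shapes; the sizes below are
+# assumptions, not values from the original code):
+#
+#   hs, ne, k = 64, 8, 2
+#   x = torch.randn(3, 5, hs)                              # [bs, sl, hs]
+#   router_w = torch.randn(ne, hs)
+#   logits, weights, idx = route_tokens_xpu(x, router_w, None, k, ne)
+#   # logits: [15, ne]; weights/idx: [15, k]; weights softmax-normalize over the
+#   # k selected experts unless moe_normalize_expert_weights rescales them.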
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
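+# Illustrative sketch (assuming the kernel's convention that bin i counts
+# occurrences of the value i, as used for per-expert token counts):
+#
+#   top_experts = torch.tensor([0, 2, 2, 1], device="cuda", dtype=torch.int32)
+#   histogram(top_experts, num_bins=4)   # -> tensor([1, 1, 2, 0])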
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
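+# Illustrative difference between the two variants, assuming the kernels follow
+# the usual exclusive/inclusive cumulative-sum convention:
+#
+#   x = torch.tensor([3, 1, 2], device="cuda")
+#   cumsum(x, dim=0, exclusive=True)    # -> tensor([0, 3, 4])
+#   cumsum(x, dim=0, exclusive=False)   # -> tensor([3, 4, 6])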
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
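+# Usage sketch (illustrative; assumes the values fit in `end_bit` bits, as expert
+# ids do in MoE routing):
+#
+#   ids = torch.tensor([2, 0, 1, 0], device="cuda", dtype=torch.int32)
+#   sorted_ids, perm = argsort(ids, end_bit=2)
+#   # sorted_ids -> tensor([0, 0, 1, 2]); perm holds the original positions in
+#   # sorted order, i.e. the gather indices for the permutation.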
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/activation_fn.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/all_to_all.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
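+# Usage sketch (illustrative; requires an initialized torch.distributed process
+# group). Each rank sends input_split_sizes[i] rows to rank i and receives
+# output_split_sizes[i] rows from rank i:
+#
+#   out, handle = all_to_all(x, output_split_sizes, input_split_sizes, group,
+#                            async_op=True)
+#   handle.wait()   # with async_op=False the handle is None and no wait is needed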
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/arguments.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+            try:
+                import triton
+            except ImportError:
+                raise ImportError('Triton is required for sparse MLP implementation')
+            # Compare parsed versions rather than raw strings so that e.g.
+            # '3.10.0' is not treated as older than '3.2.0'.
+            from packaging import version
+            if version.parse(triton.__version__) >= version.parse('3.2.0'):
+                raise ValueError(
+                    'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+                )
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
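+# Usage sketch (illustrative; SimpleNamespace stands in for Megatron's parsed
+# argument object, and constructing Arguments here assumes a CUDA device because
+# the default `device` factory is torch.cuda.current_device, plus a triton
+# version compatible with the default sparse mlp_impl):
+#
+#   from types import SimpleNamespace
+#   megatron_args = SimpleNamespace(hidden_size=2048, moe_num_experts=8, moe_top_k=2)
+#   args = from_megatron(megatron_args)   # copies every matching field onto Arguments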
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/common.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
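+# Illustrative sketch: under autocast the tensor is cast to the active autocast
+# dtype; outside autocast it is returned unchanged.
+#
+#   with torch.autocast("cuda", dtype=torch.bfloat16):
+#       y = cast_if_autocast_enabled(torch.randn(4, 4, device="cuda"))
+#       # y.dtype == torch.bfloat16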
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/dmlp_registry.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e. only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
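+# Usage sketch (illustrative; building Arguments here assumes a CUDA device and,
+# for the grouped backend, an available grouped_gemm installation):
+#
+#   args = Arguments(mlp_type='glu', mlp_impl='grouped')
+#   mlp = get(args)   # -> glu.GroupedGLU instance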
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/dmoe.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
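+# Illustrative: promote_scalar(torch.tensor(5)) -> tensor([5]); tensors that
+# already have at least one dimension are returned unchanged. This guards the
+# cumsum outputs below, which can come back 0-d (e.g. with a single expert).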
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+        # There is a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+ # the matrix muliplications. Caculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/gelu.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/glu.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Activation function and GLU gating.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/memory_test.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
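+    # The factor of 2 below assumes 2-byte (bf16/fp16) parameters and gradients.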
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/mlp.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
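+# Illustrative: scale_gradient(w, s) is an identity in the forward pass but
+# multiplies the incoming gradient of `w` by `s` in the backward pass; the MoE
+# layers below use it with s = 1 / expert_parallel_world_size.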
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+    # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class.
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # Use a weighted sum for the shared expert output,
+ # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/moe.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} tokens_per_expert entries '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
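+
+# Worked example of the scale above (hypothetical numbers): with
+# moe_num_experts=8, moe_loss_weight=0.1, num_layers=2, tokens=1024 and
+# moe_top_k=2,
+#
+#     scale = 0.1 * 8 / (2 * 1024 * 2) = 0.8 / 4096 ~= 1.95e-4
+#
+# and the returned loss is scale * dot(tokens_per_expert, expert_scores).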
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
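+
+ # Worked example (hypothetical numbers): with moe_top_k=2, tokens=4096,
+ # an expert-parallel world size of 1, 64 experts and
+ # moe_capacity_factor=1.25:
+ #     int(1.25 * (2 * 4096 * 1 / 64)) = int(1.25 * 128) = 160 tokens/expert.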
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # beforehand? Could we instead have the `torch.max` operation
+ # return 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
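+
+ # Illustrative example in plain terms (toy values, not the fused ops used
+ # above): for top_expert = [2, 0, 1, 0] and num_experts = 3,
+ #     bin_ids           = [0, 0, 1, 2]  # sorted expert ids
+ #     indices           = [1, 3, 2, 0]  # permutation grouping tokens by expert
+ #     tokens_per_expert = [2, 1, 1]
+ #     bins              = [2, 3, 4]     # inclusive cumsum = end of each bin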
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+ # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/mpu.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
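+
+# Worked example (hypothetical configuration): with an expert parallel world
+# size of 8 and moe_num_experts=4, expert_sharding_degree() is min(8, 4) = 4,
+# hidden_sharding_degree() is 8 // 4 = 2, experts_per_rank() is 4 // 4 = 1,
+# and features_per_rank() is ffn_hidden_size // 2, i.e. each expert's FFN
+# weights are split across two ranks.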
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/router.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
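+
+# Shape sketch for the forward pass above (hypothetical example): for input x
+# of shape [sl, bs, hs] the router returns
+#     scores:         [sl * bs, moe_num_experts]  (softmax over experts)
+#     expert_weights: [sl * bs, moe_top_k]
+#     expert_indices: [sl * bs, moe_top_k]
+# e.g. with moe_top_k=2 and a scores row of [0.1, 0.6, 0.3] the selected
+# experts are [1, 2] with weights [0.6, 0.3] (optionally renormalized when
+# moe_normalize_expert_weights is set).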
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_layers/sharedexpert_registry.py b/build/torch211-cxx11-cu130-aarch64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch211-cxx11-cu130-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..49f9a4a2738530bc50ce8497de9d5206075a5f2e
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0926657a5bf049020d315e3281ca24c455d318a0ec8d9afc14665a79f8c2f19c
+size 12073200
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix an op name with this extension's namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_version.py b/build/torch211-cxx11-cu130-aarch64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/backend/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/backend/kernels.py b/build/torch211-cxx11-cu130-aarch64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have
+# CUDA. This preserves the original code while still enabling tests without a GPU.
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
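+
+# Reference sketch (editorial illustration in plain PyTorch, not the Triton
+# kernel actually used): ignoring the blocking, gather() copies row
+# `indices[i] // top_k` of `x` into output row i, optionally scaled by
+# `weights[indices[i]]`.
+def _reference_gather(x, indices, weights, top_k):
+ idx = indices.long()
+ out = x[idx // top_k]
+ if weights is not None:
+ out = out * weights[idx].unsqueeze(-1)
+ return out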
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per entry in 'indices'. Array 'x' has a greater or
+ # equal number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/benchmark_util.py b/build/torch211-cxx11-cu130-aarch64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
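+
+
+# Usage sketch (hypothetical sizes, commented out so importing this module has
+# no side effects):
+#
+#     x = torch.randn(4096, 4096, device='cuda')
+#     w = torch.randn(4096, 4096, device='cuda')
+#     mean_ms, std_ms = benchmark_function(lambda: x @ w)
+#     log_benchmark('Dense GEMM', {'m': 4096, 'n': 4096, 'k': 4096}, mean_ms, std_ms)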
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/cpu_fused_moe.py b/build/torch211-cxx11-cu130-aarch64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
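+
+
+# Usage sketch for route_tokens_cpu (hypothetical shapes, commented out):
+#
+#     x = torch.randn(4, 64)            # 4 tokens, hidden_size=64
+#     router_w = torch.randn(8, 64)     # 8 experts
+#     logits, weights, ids = route_tokens_cpu(x, router_w, None, moe_top_k=2, moe_num_experts=8)
+#     # logits: [4, 8]; weights, ids: [4, 2]; weights sum to 1 over the top-k dim.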
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation loops over the experts, but batches all tokens routed
+ to each expert into single tensor operations instead of per-token loops,
+ which keeps the work efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Process the experts one at a time: for each expert, find the
+ # (token_idx, topk_pos) pairs that were routed to it.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
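+
+# Usage sketch (hypothetical shapes, standard non-interleaved layout,
+# commented out):
+#
+#     E, H, I = 4, 32, 64
+#     x = torch.randn(8, H)                       # 8 tokens
+#     w1 = torch.randn(E, H, 2 * I)               # per-expert gate/up projection
+#     w2 = torch.randn(E, I, H)                   # per-expert down projection
+#     _, tw, ti = route_tokens_cpu(x, torch.randn(E, H), None, 2, E)
+#     y = cpu_fused_moe(x, w1, w2, tw, ti, activation="silu", is_interleaved=False)
+#     # y.shape == x.shape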
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
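+# Minimal usage sketch for this CPU path (illustrative only; the attribute names on
+# `router`/`experts` and all shapes below are assumptions, not taken from a specific
+# checkpoint). Kept as a comment so importing this module stays side-effect free:
+#
+#     mlp = MegaBlocksMoeMLP()
+#     mlp.router = torch.nn.Linear(64, 8)                # hidden_size -> num_experts
+#     mlp.router.top_k = 2
+#     mlp.experts = torch.nn.Module()
+#     mlp.experts.num_experts = 8
+#     mlp.experts.gate_up_proj = torch.nn.Parameter(torch.randn(8, 64, 256) * 0.02)
+#     mlp.experts.down_proj = torch.nn.Parameter(torch.randn(8, 128, 64) * 0.02)
+#     out, weights = mlp(torch.randn(2, 5, 64))          # out: [2, 5, 64]
+#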
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/cpu_moe_cpp.py b/build/torch211-cxx11-cu130-aarch64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
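+# Sketch of a direct call (illustrative shapes; in practice topk_weights/topk_ids come
+# from route_tokens_cpu, and the expert weights may first need VNNI packing via
+# ops.convert_weight_packed, as done in CPUMegaBlocksMoeMLP.forward below):
+#
+#     M, K, N, E, top_k = 8, 64, 128, 16, 2
+#     x = torch.randn(M, K, dtype=torch.bfloat16)
+#     w1 = torch.randn(E, 2 * N, K, dtype=torch.bfloat16)   # gate+up, [E, 2N, K]
+#     w2 = torch.randn(E, K, N, dtype=torch.bfloat16)       # down,    [E, K, N]
+#     topk_weights = torch.rand(M, top_k)
+#     topk_ids = torch.randint(0, E, (M, top_k), dtype=torch.int32)
+#     out = fused_moe_cpp(x, w1, w2, topk_weights, topk_ids, is_vnni=False)  # [M, K]
+#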
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+        # Lazy one-time setup for GPT-OSS style checkpoints: detect MXFP4 precision
+        # configs and pack scales/weights into the layout expected by the C++ kernel.
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
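+
+# Usage sketch (illustrative sizes; grouped_gemm conventionally expects batch_sizes as a
+# CPU int64 tensor). Rows of all groups are stacked along dim 0 of 'a', with one weight
+# matrix per group in 'b':
+#
+#     a = torch.randn(5, 16, device="cuda", dtype=torch.bfloat16)      # 2 + 3 rows
+#     b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16)  # [num_groups, 16, 32]
+#     sizes = torch.tensor([2, 3])
+#     c = gmm(a, b, sizes)                                             # -> [5, 32]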
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm/ops.py b/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm_util.py b/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored in this package (see the grouped_gemm/ subpackage), so
+    # this import cannot fail here; the try/except mirrors the upstream megablocks code.
+    _grouped_gemm_is_available = True
+except ImportError:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/layers.py b/build/torch211-cxx11-cu130-aarch64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
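+# The wrappers installed above all follow the same pattern (sketch): under torch.compile
+# tracing, return an empty tensor with the op's output shape/dtype so shape inference can
+# proceed without launching the CUDA kernel; otherwise dispatch to the original op:
+#
+#     def op_with_meta(x, *args):
+#         if torch.compiler.is_compiling():
+#             return torch.empty_like(x)   # meta path: shapes/dtypes only
+#         return original_op(x, *args)     # eager path: real kernel
+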
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
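+# Worked example of the sharding helpers above (illustrative numbers only):
+#
+#     expert_sharding_degree(world_size=8, moe_num_experts=4)      # -> 4
+#     hidden_sharding_degree(8, 4, ffn_hidden_size=3072)           # -> 8 // 4 = 2
+#     experts_per_rank(4, 8)                                       # -> 1 expert per rank
+#     features_per_rank(3072, 8, 4)                                # -> 3072 // 2 = 1536
+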
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
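+# Routing sketch (illustrative shapes; the input is flattened to tokens before routing):
+#
+#     router = torch.nn.Linear(64, 16)
+#     logits, weights, indices = route_tokens(
+#         torch.randn(2, 3, 64), router.weight, router.bias,
+#         moe_top_k=2, moe_num_experts=16,
+#     )
+#     # logits: [6, 16]; weights/indices: [6, 2], softmaxed over the selected experts
+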
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
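+# Note: the activation in mlp_forward above is the clamped "swiglu-oai" variant used by
+# GPT-OSS style models: gate is clamped from above at `limit`, up is clamped to
+# [-limit, limit], and the down projection is applied to (up + 1) * gate * sigmoid(alpha * gate).
+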
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
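+# Example of the weighted-sum path (illustrative): with moe_top_k=3,
+#
+#     combine_expert_shared_outputs(s, e, shared_expert_weighted_sum=True, moe_top_k=3)
+#
+# returns 0.25 * s + 0.75 * e, i.e. the shared expert counts as one of four experts.
+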
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} tokens_per_expert entries "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
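+# Worked example (illustrative): 1024 tokens, top_k=4, 128 experts, a single rank and
+# moe_capacity_factor=1.0 give tokens_per_expert = 4 * 1024 / 128 = 32, so the returned
+# capacity is 32 slots per expert.
+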
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
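+# Example (illustrative): with 4 experts, top_k=1, 8 tokens and perfectly balanced routing,
+# tokens_per_expert = [2, 2, 2, 2] and expert_scores.mean(dim=0) = [0.25, 0.25, 0.25, 0.25],
+# so the loss is (4 / (8 * 1)) * dot([2, 2, 2, 2], [0.25] * 4) = 0.5 * 2.0 = 1.0.
+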
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+    # Asynchronously exchange token counts across expert-parallel ranks
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
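+# Usage sketch (illustrative sizes; `mlp` is assumed to be a MegaBlocksMoeMLPWithSharedExpert
+# instance, defined below):
+#
+#     up_w, down_w, up_b, down_b = create_shared_expert_weights(
+#         hidden_size=1152, shared_expert_hidden_size=3072,
+#         device=torch.device("cuda"), dtype=torch.bfloat16,
+#         init_method=torch.nn.init.kaiming_uniform_,
+#     )
+#     mlp.set_shared_expert_weights(up_w, down_w, up_b, down_b)
+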
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert", "CPUMegaBlocksMoeMLP"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/megablocks/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/metadata.json b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a9813b81c6c98110d265c184f2016d728202289b
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/all_to_all_benchmark.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/binned_gather.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/binned_scatter.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/cumsum.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
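+
+
+# Illustrative note (not part of the upstream source): for a 1-D tensor such as
+# tokens_per_expert = [1, 2, 3], inclusive_cumsum(x, 0) yields the bin upper
+# bounds [1, 3, 6], while exclusive_cumsum(x, 0) yields the bin starts [0, 1, 3].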
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/gather.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/histogram.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
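+
+
+# Illustrative note (not part of the upstream source): the MoE layers call this
+# as tokens_per_expert = histogram(top_expert, num_experts), i.e. it counts how
+# many entries of `x` fall into each integer bin in [0, max_val).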
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/padded_gather.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/padded_scatter.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/repeat.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
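+
+
+# Illustrative note (not part of the upstream source): a tiling of all ones is a
+# no-op, e.g. repeat(x, torch.Size((1, 1))) returns x unchanged, while
+# repeat(x, torch.Size((2, 1))) tiles x twice along dim 0.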
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/replicate.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/round_up.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
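+
+
+# Illustrative example (not part of the upstream source): each entry is padded
+# up to the nearest multiple of `value` with a truncating divide, e.g.
+#
+#   x = torch.tensor([1, 128, 129], dtype=torch.int32)
+#   round_up(x, 128)  # -> tensor([128, 128, 256], dtype=torch.int32)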
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/scatter.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/sort.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
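+
+
+# Illustrative usage sketch (not part of the upstream source): the op returns the
+# radix-sorted keys together with the permutation that produced them, which the
+# MoE layers use to group token indices by expert, e.g.
+#
+#   top_expert = torch.tensor([2, 0, 1, 0], dtype=torch.int32, device="cuda")
+#   bin_ids, indices = sort(top_expert, 2)  # end_bit=2 covers values < 4
+#   # bin_ids -> [0, 0, 1, 2]; indices -> the input positions of those tokens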
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/stk_autocast.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/sum.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
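+
+
+# Illustrative note (not part of the upstream source): when the reduced dimension
+# has length one the values are unchanged, so a squeeze suffices, e.g.
+# sum(torch.ones(4, 1, 8), dim=1) returns a (4, 8) tensor of ones without
+# launching a reduction kernel.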
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/ops/topology.py b/build/torch211-cxx11-cu130-aarch64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/autocast.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/sputnik.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/triton_kernels.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
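+
+
+# Illustrative CPU reference for the kernel above (documentation only; the
+# helper below is a sketch and is not used by the Triton path): row_indices
+# expands BCSR row offsets into one row id per nonzero block, e.g.
+# offsets [0, 2, 3] -> row indices [0, 0, 1].
+def _row_indices_reference(offsets):
+    out = []
+    for row in range(len(offsets) - 1):
+        out.extend([row] * int(offsets[row + 1] - offsets[row]))
+    return out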
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/matrix.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D data.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
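+
+# Worked example for _transpose (illustrative): with row_indices = [0, 0, 1]
+# and column_indices = [2, 0, 1], argsort gives gather_indices = [1, 2, 0],
+# so column_indices_t = [0, 1, 0], block_offsets_t = [1, 2, 0] and
+# offsets_t = [0, 1, 2, 3] (one nonzero block per block column).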
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+                "t() expects a 2D Matrix, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
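+
+
+# Minimal construction sketch (illustrative documentation; the values are made
+# up and the guard keeps importing this module side-effect free):
+if __name__ == "__main__":
+    blocking = 2
+    # Two nonzero 2x2 blocks of a 4x4 matrix: (block row 0, block col 1)
+    # and (block row 1, block col 0).
+    example = Matrix((4, 4),
+                     torch.ones(2, blocking, blocking, dtype=torch.float16),
+                     torch.tensor([0, 1], dtype=torch.int16),
+                     torch.tensor([1, 0], dtype=torch.int16),
+                     torch.tensor([0, 1, 2], dtype=torch.int32))
+    example.validate()
+    print(example.shape, example.nnz, example.blocking)  # (4, 4) 8 2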
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/linear_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
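+
+
+# Usage sketch (illustrative; assumes a CUDA device and the package layout used
+# by the commented-out tests in linear_ops_test.py, run e.g. via `python -m`):
+if __name__ == "__main__":
+    import stk
+
+    if torch.cuda.is_available():
+        m = k = n = 256
+        blocking = 128
+        mask = stk.random.dense_mask(m, k, 0.5, blocking)
+        a = stk.ops.to_sparse((torch.randn(m, k) * mask).half(), blocking).cuda()
+        b = torch.randn(k, n, dtype=torch.float16, device="cuda")
+        print(dsd(a, b).shape)                   # sparse @ dense -> dense (m, n)
+        print(dds(b.t().contiguous(), a).shape)  # dense @ sparse -> dense (n, k)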
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
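+
+# Worked example (illustrative): with blocking=2 the block coordinate (1, 2)
+# expands to the four element coordinates of that block:
+# [[2, 4], [2, 5], [3, 4], [3, 5]].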
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
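+
+
+# Round-trip sketch (illustrative; the conversions themselves run on CPU and
+# mirror the commented-out MatrixOpsTest in matrix_ops_test.py):
+if __name__ == "__main__":
+    x = (torch.randn(8, 16) * (torch.rand(8, 16) > 0.5)).half()
+    sparse_x = to_sparse(x, blocking=1)
+    assert torch.equal(to_dense(sparse_x), x)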
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/random/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/random/random_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
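+
+
+# Quick check (illustrative): a 4x4 mask at 50% block sparsity with 2x2 blocks
+# keeps round(4 * 0.5) = 2 of the 4 blocks, i.e. 8 nonzero entries.
+if __name__ == "__main__":
+    example = dense_mask(4, 4, 0.5, blocking=2)
+    assert example.shape == (4, 4) and int(example.sum()) == 8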
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py b/build/torch211-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from .. import random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/xpu_fused_moe.py b/build/torch211-cxx11-cu130-aarch64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM path: expert offsets are computed on the host from per-expert token counts.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+        for x in arr:
+ prefix.append(prefix[-1] + x)
+ return prefix
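+    # e.g. exclusive_prefix_sum([3, 1, 2]) -> [0, 3, 4, 6]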
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
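+
+# Example (illustrative): with 1,000 tokens and 8 experts per node the loop
+# settles on 128 tokens per block, since ceil(1000 / 128) * 8 = 64 <= 128
+# while the smaller candidate block sizes fail that check.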
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+    # Convert u4 nibbles to s4 so the GEMM kernel does not need to apply a
+    # zero point. Only the default zero point (8) is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
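+
+# Worked example (illustrative): an input byte 0x00 (two u4 zeros) maps to
+# 0x88 (two s4 values of -8), and 0xFF (two u4 fifteens) maps to 0x77
+# (two s4 values of +7).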
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ op; temporarily exposed here until GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
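+
+# Shape sketch (illustrative): for 4 tokens of hidden size 16 with 8 experts
+# and top-2 routing, route_tokens_xpu(torch.randn(4, 16), torch.randn(8, 16),
+# None, moe_top_k=2, moe_num_experts=8) returns logits of shape (4, 8) plus
+# expert weights and indices of shape (4, 2), with the weights softmaxed over
+# the two selected experts.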
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
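+
+# Illustrative example (not part of the original source): for a 1-D tensor
+# x = [1, 2, 3] along dim 0,
+#   inclusive_cumsum -> [1, 3, 6]
+#   exclusive_cumsum -> [0, 1, 3]
+# Both wrappers copy the kernel result into the caller-provided `out` tensor.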
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
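+
+# Illustrative usage (hypothetical values, assumes a CUDA device and the
+# compiled ops): counting how many tokens were routed to each of 4 experts.
+#
+#   expert_ids = torch.tensor([0, 2, 2, 3], device="cuda", dtype=torch.int32)
+#   histogram(expert_ids, num_bins=4)  # -> tensor([1, 0, 2, 1])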
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
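+
+# Illustrative usage (not part of the original source): turning per-expert
+# token counts into bin offsets.
+#
+#   counts = histogram(expert_ids, num_bins=4)      # e.g. [1, 0, 2, 1]
+#   starts = cumsum(counts, dim=0, exclusive=True)  # [0, 1, 1, 3] - bin start offsets
+#   ends = cumsum(counts, dim=0)                     # [1, 1, 3, 4] - bin end offsets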
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
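+
+# Illustrative usage (not part of the original source): sorting expert ids
+# while tracking which token each slot came from.
+#
+#   ids = torch.tensor([2, 0, 1, 0], device="cuda", dtype=torch.int32)
+#   sorted_ids, order = argsort(ids)  # sorted_ids=[0, 0, 1, 2], order=[1, 3, 2, 0]
+#
+# Note that `order` shares the dtype of `ids`, matching the underlying kernel.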
+
+
+# Export public API
+__all__ = [
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/activation_fn.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/all_to_all.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
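+
+
+# Illustrative usage (hypothetical sizes, assumes torch.distributed is
+# initialized): each rank sends `input_split_sizes[i]` rows of `x` to rank i
+# and receives `output_split_sizes[i]` rows from rank i.
+#
+#   out, handle = all_to_all(x, output_split_sizes, input_split_sizes, group,
+#                            async_op=True)
+#   handle.wait()  # only needed when async_op=True; otherwise handle is None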
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/arguments.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (allows using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using a weighted sum for the shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+                import triton
+                from packaging import version
+                # Compare parsed versions; a plain string comparison misorders e.g. '3.10.0' vs '3.2.0'.
+                if version.parse(triton.__version__) >= version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
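+
+
+# Illustrative usage (hypothetical values, not part of the original source):
+#
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_impl='grouped',  # sidesteps the sparse-path triton>=3.2.0 restriction
+#       fp16=False,
+#       bf16=True,
+#   )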
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/common.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
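+
+
+# Illustrative behaviour (not part of the original source): inside
+# `torch.autocast(device_type='cuda', dtype=torch.bfloat16)`, a float32 CUDA
+# tensor passed to cast_if_autocast_enabled() comes back as bfloat16; outside
+# an autocast region the tensor is returned unchanged.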
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/dmlp_registry.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
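+
+
+# Illustrative usage (not part of the original source):
+#
+#   args = Arguments(mlp_type='glu', mlp_impl='grouped')
+#   expert_mlp = get(args)  # -> glu.GroupedGLU instance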
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/dmoe.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+        # There is a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
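+
+    # Illustrative example (hypothetical values): with blocking=128 and
+    # tokens_per_expert = [3, 1], the padded counts are [128, 128], so
+    # padded_bins = [128, 256] and bins = [3, 4]; padding gives each expert a
+    # block-aligned region in the permuted activation matrix.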
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/gelu.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
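+    # Derivative of the tanh-approximate GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))):
+    # 0.79788456 is approx. sqrt(2/pi) and 0.1070322243 is approx. 3 * 0.044715 * sqrt(2/pi).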
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/glu.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Apply the activation and GLU gating.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/memory_test.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/mlp.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
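+
+    # Shape sketch (illustrative, for the capacity-based MoE path that feeds
+    # this module): x is [experts, capacity, hidden_size], w1 is
+    # [experts, hidden_size, ffn_per_rank] and w2 is [experts, ffn_per_rank,
+    # hidden_size], so each bmm batches one GEMM per expert.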
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+            # Use a weighted sum for the shared expert output,
+            # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/moe.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
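+
+    # Worked example (hypothetical values): with moe_top_k=2, 1024 tokens, a
+    # world size of 1, 8 experts and moe_capacity_factor=1, each expert keeps
+    # at most int(1 * 2 * 1024 * 1 / 8) = 256 tokens; the rest are dropped.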
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
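+        # Toy illustration (hypothetical): with 2 devices and 4 experts (2 per
+        # device), tokens routed as [e3, e0, e2, e1] are first sorted locally
+        # into [e0, e1 | e2, e3]; the all-to-all then delivers the e0/e1 group
+        # to device 0 and the e2/e3 group to device 1, and each device re-sorts
+        # its received chunks by expert before running the expert MLPs.
+        #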
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+        # If we're sharding the experts along the hidden dimension,
+        # multiple devices own parts of the same set of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+        # If we're sharding the experts along the hidden dimension,
+        # multiple devices own parts of the same set of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+            # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/mpu.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
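+#
+# Illustrative (hypothetical) example: an expert-parallel world size of 8 with
+# moe_num_experts=4 gives expert_sharding_degree=4 and hidden_sharding_degree=2,
+# i.e. each rank owns one expert and each expert's FFN features are split 2 ways.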
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/router.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
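+#
+# For example, with num_experts=4 the first eight entries of the assignment
+# (flattened over tokens and top-k slots) are [0, 1, 2, 3, 0, 1, 2, 3],
+# regardless of the router logits.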
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch211-cxx11-cu130-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..e07c8ab8c5e122896ef80f10303311422b02dc06
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f548de7e56f7b35bddd555b88836ff77d731dfa6d71c52c2198a54607dba186
+size 12041592
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_version.py b/build/torch211-cxx11-cu130-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/backend/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/backend/kernels.py b/build/torch211-cxx11-cu130-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
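+    # For example (illustrative), with two experts and padded_bins = [128, 256],
+    # the gathered output has 256 rows regardless of how many tokens were
+    # actually routed to each expert.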
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/benchmark_util.py b/build/torch211-cxx11-cu130-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
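+
+
+# Example usage (illustrative; `model` and `x` are hypothetical):
+#
+#   mean_ms, std_ms = benchmark_function(lambda: model(x))
+#   log_benchmark('MoE forward', {'batch_size': x.shape[0]}, mean_ms, std_ms)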
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/cpu_fused_moe.py b/build/torch211-cxx11-cu130-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
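+# Shape sketch (illustrative): for x of shape [2, 8, 2880], a router_weight of
+# shape [32, 2880] and moe_top_k=4, logits has shape [16, 32] and both
+# expert_weights and expert_indices have shape [16, 4].
+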
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation processes all experts in parallel using batched operations
+ instead of sequential for loops, which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Build expert mask: which tokens go to which expert
+ # expert_mask[expert_id] contains indices of (token_idx, topk_pos) pairs
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/cpu_moe_cpp.py b/build/torch211-cxx11-cu130-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
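+# Example call (illustrative; the tensor names are hypothetical):
+#
+#   out = fused_moe_cpp(x_flat, w1, w2, topk_weights, topk_ids.to(torch.int32),
+#                       w1_bias=b1, w2_bias=b2, alpha=1.702, limit=7.0)
+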
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+                # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8 (mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/ops.py b/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
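+
+
+# Hedged usage sketch for gmm (shapes, dtype and device are assumptions for
+# illustration only, not a documented API contract):
+#
+#   a = torch.randn(6, 16, device="cuda", dtype=torch.bfloat16)      # 6 rows
+#   b = torch.randn(3, 16, 32, device="cuda", dtype=torch.bfloat16)  # 3 groups
+#   batch_sizes = torch.tensor([2, 3, 1])  # rows of `a` assigned to each group
+#   out = gmm(a, b, batch_sizes)           # out has shape [6, 32]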
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm_util.py b/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # grouped_gemm is vendored inside this package, so importing the local copy
+ # is enough to confirm it is available.
+ from . import grouped_gemm  # noqa: F401
+ _grouped_gemm_is_available = True
+except ImportError:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+# Expose the vendored grouped_gemm modules under their original names.
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/layers.py b/build/torch211-cxx11-cu130-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
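+# Illustrative sketch of what the patched ops enable under torch.compile
+# (`routed` and its arguments are assumptions, not part of this module's API):
+#
+#   @torch.compile
+#   def routed(top_experts):
+#       bin_ids, indices = ops.sort(top_experts, 7)
+#       counts = ops.histogram(top_experts, 128)
+#       return bin_ids, indices, counts
+#
+# While torch.compile traces this function, torch.compiler.is_compiling() is
+# True, so the wrappers above return correctly shaped empty tensors instead of
+# launching the CUDA kernels, which lets tracing proceed.
+
+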
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
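+# Illustrative sketch of how the sharding helpers compose (the world size,
+# expert count and FFN width below are assumptions, not defaults):
+#
+#   world_size, num_experts, ffn = 4, 128, 3072
+#   expert_sharding_degree(world_size, num_experts)        # 4
+#   hidden_sharding_degree(world_size, num_experts, ffn)   # 1
+#   experts_per_rank(num_experts, world_size)              # 32
+#   features_per_rank(ffn, world_size, num_experts)        # 3072
+
+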
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
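+# Hedged usage sketch for route_tokens (shapes are assumptions for
+# illustration):
+#
+#   x = torch.randn(2, 8, 1152, device="cuda")            # [bs, sl, hs]
+#   router_weight = torch.randn(128, 1152, device="cuda")
+#   logits, weights, idx = route_tokens(
+#       x, router_weight, None, moe_top_k=4, moe_num_experts=128)
+#   # logits: [16, 128]; weights and idx: [16, 4] (top-4 experts per token)
+
+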
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
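+# The block above is the clamped, interleaved SwiGLU used by gpt-oss style
+# experts. Worked example on assumed scalars (alpha = 1.702, limit = 7.0):
+#   gate = 2.0, up = 0.5
+#   glu = gate * sigmoid(gate * alpha) = 2.0 * sigmoid(3.404) ≈ 1.936
+#   (up + 1) * glu ≈ 1.5 * 1.936 ≈ 2.90   # value fed into the down projection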
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
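+# Example of the weighted-sum branch (assumed moe_top_k = 3): the shared expert
+# gets weight 1 / (3 + 1) = 0.25 and the routed experts get 3 / 4 = 0.75, so
+# the output is 0.25 * shared_expert_out + 0.75 * expert_out.
+
+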
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
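+# Worked example for expert_capacity (assumed values): tokens = 4096,
+# top_k = 4, num_experts = 128, a single rank and capacity_factor = 1.0 give
+#   tokens_per_expert = 4 * 4096 * 1 / 128 = 128
+# so each expert is sized for int(1.0 * 128) = 128 token slots.
+
+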
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+ tokens, score_num_experts = expert_scores.size()
+ assert score_num_experts == num_experts
+ assert len(tokens_per_expert.size()) == 1
+ (count_num_experts,) = tokens_per_expert.size()
+ assert count_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
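+# Worked example (assumed values): with num_experts = 4, top_k = 2 and 8 tokens
+# routed perfectly uniformly, tokens_per_expert = [4, 4, 4, 4] and the mean
+# router score per expert is 0.25, so
+#   loss = (4 / (8 * 2)) * dot([4, 4, 4, 4], [0.25, 0.25, 0.25, 0.25]) = 1.0,
+# the value this auxiliary loss takes for a perfectly balanced router.
+
+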
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
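+# Sketch of the metadata indices_and_bins produces (inputs are assumptions):
+# for top_expert = [2, 0, 2, 1] and num_experts = 4,
+#   tokens_per_expert = [1, 1, 2, 0]
+#   bins (inclusive cumsum)    = [1, 2, 4, 4]
+#   bin_ids (sorted expert id) = [0, 1, 2, 2]
+#   indices (sort permutation) = [1, 3, 0, 2]
+
+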
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Launch the exchange asynchronously so the local permutation below
+ # overlaps with the token-count communication.
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
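+# Hedged usage sketch for create_shared_expert_weights (sizes and the
+# initializer are assumptions for illustration):
+#
+#   up_w, down_w, up_b, down_b = create_shared_expert_weights(
+#       hidden_size=1152,
+#       shared_expert_hidden_size=3072,
+#       device=torch.device("cuda"),
+#       dtype=torch.bfloat16,
+#       init_method=torch.nn.init.xavier_uniform_,
+#   )
+#   # up_w: [3072, 1152], down_w: [1152, 3072]; biases are returned as None
+
+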
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/megablocks/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a9813b81c6c98110d265c184f2016d728202289b
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/binned_gather.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/binned_scatter.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/cumsum.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/gather.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/histogram.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/histogram_benchmark.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/matmul_benchmark.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/padded_gather.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
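
`padded_gather` permutes tokens into contiguous per-expert blocks whose sizes are rounded up so that each expert's slice aligns with the 128-wide blocking used by the block-sparse matmuls. The sketch below mirrors the (commented-out) benchmark code in this tree and is illustrative only; it assumes a CUDA device and the same `megablocks` import path as above.

```python
import torch
from megablocks import ops  # assumed import path

sl, hs, ne, top_k = 4096, 1024, 8, 1
x = torch.randn(sl, hs, dtype=torch.float16, device="cuda")

# Route each token to an expert and sort the assignments by expert id.
top_expert = torch.randint(0, ne, (sl * top_k,), dtype=torch.int32, device="cuda")
bin_ids, indices = ops.sort(top_expert)

# Per-expert counts, padded up to the 128-wide block size.
tokens_per_expert = ops.histogram(top_expert, ne)
padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)

# Tokens laid out expert-by-expert, with padding rows between experts.
x_permuted = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
```
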
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/padded_scatter.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
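
`padded_scatter` is the inverse of `padded_gather`: it un-permutes the expert outputs back to the original token order and, when `top_k > 1`, combines the `top_k` copies of each token using the router `weights`. Continuing the hedged sketch after `padded_gather` above (same assumptions):

```python
# The expert MLPs would run on x_permuted here; reuse it to show the un-permute.
weights = torch.rand(sl * top_k, dtype=torch.float16, device="cuda")

x_restored = ops.padded_scatter(
    x_permuted, indices, bin_ids, weights, bins, padded_bins, top_k,
)
assert x_restored.shape == (sl, hs)
```
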
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/permute_benchmark.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/repeat.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
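
`repeat` is a thin wrapper around `torch.Tensor.repeat` that skips the copy when the tiling is all ones; a tiny illustration of the function defined above:

```python
import torch

x = torch.arange(4)
assert repeat(x, torch.Size((1,))) is x            # all-ones tiling: returns x unchanged
assert repeat(x, torch.Size((2,))).shape == (8,)   # otherwise falls through to x.repeat(...)
```
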
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/replicate.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/round_up.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
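
`round_up` rounds every element of an int32 tensor up to the nearest multiple of `value` using truncating division, which is how the padded gather/scatter paths align per-expert token counts with the 128-wide block size. A small example using the function above:

```python
import torch

tokens_per_expert = torch.tensor([5, 128, 130, 0], dtype=torch.int32)
print(round_up(tokens_per_expert, 128))
# tensor([128, 128, 256,   0], dtype=torch.int32)
```
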
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/scatter.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/sort.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
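
`ops.sort` radix-sorts an integer tensor and also returns a second tensor that the gather/scatter code above consumes as the argsort permutation of the input; the optional `end_bit` bounds how many low bits the radix sort inspects (it defaults to the full width of the dtype). A hedged sketch, assuming a CUDA device and the `megablocks` import path used earlier:

```python
import torch
from megablocks import ops  # assumed import path

top_expert = torch.randint(0, 8, (1024,), dtype=torch.int32, device="cuda")

# 8 expert ids fit in 3 bits, so only the low 3 bits need to be sorted.
bin_ids, indices = ops.sort(top_expert, 3)

# bin_ids are the expert ids in sorted order; indices maps back to token positions.
assert torch.equal(bin_ids, top_expert[indices.long()])
```
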
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/sort_benchmark.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/stk_autocast.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+        # Recursively cast dictionary keys and values.
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
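
These vendored decorators make a custom `torch.autograd.Function` cooperate with `torch.autocast`: `custom_fwd` casts eligible CUDA floating-point arguments to the current autocast dtype and then runs the forward with autocast disabled, and `custom_bwd` runs the backward with autocast disabled as well. A minimal sketch of how the ops in this directory apply them; `ScaleOp` and the import path are illustrative assumptions, not part of the library.

```python
import torch
from megablocks.ops.stk_autocast import custom_bwd, custom_fwd  # assumed import path


class ScaleOp(torch.autograd.Function):
    # Hypothetical op used only to illustrate the decorators.

    @staticmethod
    @custom_fwd
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Under autocast, x arrives here already cast to the autocast dtype
        # (e.g. float16) and autocast is disabled for the body.
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad: torch.Tensor) -> torch.Tensor:
        return grad * 2


with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = ScaleOp.apply(torch.randn(8, device="cuda", requires_grad=True))
```
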
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/sum.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
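
This `sum` helper avoids launching a reduction kernel when the reduced dimension has length one, returning a squeeze instead; the result matches `x.sum(dim)` either way. A quick check using the function above (which intentionally shadows the builtin `sum` inside the ops package):

```python
import torch

x = torch.randn(1, 4)
y = torch.randn(3, 4)

assert torch.equal(sum(x, dim=0), x.squeeze(0))      # length-1 dim: just a squeeze
assert torch.allclose(sum(y, dim=0), y.sum(dim=0))   # otherwise: a real reduction
```
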
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/ops/topology.py b/build/torch211-cxx11-cu130-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
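
`ops.topology` materializes the column indices (as int16) of the nonzero blocks in the block-sparse "topology" matrix that the dMoE layer multiplies against, given the padded per-expert offsets; the commented-out matmul benchmark above shows the full construction. A hedged sketch under the same CUDA and import-path assumptions, with illustrative sizes:

```python
import torch
from megablocks import ops  # assumed import path

blocking, ffn_hidden_size = 128, 512
padded_tokens_per_expert = torch.tensor([256, 128, 384, 256], dtype=torch.int32, device="cuda")
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)

block_rows = int(padded_tokens_per_expert.sum()) // blocking
blocks_per_row = ffn_hidden_size // blocking

# One int16 column index per nonzero block: block_rows * blocks_per_row entries.
column_indices = ops.topology(padded_bins, blocking, block_rows, blocks_per_row)
```
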
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/autocast.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+        # Recursively cast dictionary keys and values.
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/sputnik.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/triton_kernels.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
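
`_row_indices_kernel` expands BCSR row offsets into an explicit per-block row index: each program handles one block row and writes its own row id into every nonzero-block slot of that row. The pure-PyTorch reference below computes the same thing on CPU and can be used to sanity-check the kernel; the function name is illustrative only.

```python
import torch


def row_indices_reference(offsets: torch.Tensor) -> torch.Tensor:
    # offsets[i]:offsets[i + 1] is the range of nonzero blocks in block row i.
    out = torch.empty(int(offsets[-1]), dtype=torch.int32)
    for row in range(offsets.numel() - 1):
        out[int(offsets[row]):int(offsets[row + 1])] = row
    return out


# Three block rows with 2, 0 and 3 nonzero blocks respectively.
print(row_indices_reference(torch.tensor([0, 2, 2, 5])))
# tensor([0, 0, 2, 2, 2], dtype=torch.int32)
```
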
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/matrix.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D data.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
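+
+
+# Usage sketch (assuming this package is importable as `stk` with the sibling
+# stk.ops helpers; the calls below are illustrative, not part of this module):
+#
+#   import torch, stk
+#
+#   dense = torch.randn(256, 256).to(torch.float16)
+#   sparse = stk.ops.to_sparse(dense, blocking=128)    # -> Matrix in BCSR form
+#   sparse.validate()
+#   print(sparse.shape, sparse.nnz, sparse.blocking)   # (256, 256), 4 * 128**2, 128
+#   round_trip = stk.ops.to_dense(sparse)              # back to a dense torch.Tensor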
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+ to use the topoplogy of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
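+
+# Usage sketch (a and b are stk Matrix objects that share a sparsity pattern,
+# e.g. both built via ops.to_sparse from dense tensors with the same mask):
+#
+#   c = mul(a, b)                    # element-wise product, reuses a's topology
+#   dense_c = ops.to_dense(c)        # materialize for inspection (from .matrix_ops)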
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/linear_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
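+# Naming convention: each op is <output><lhs><rhs>, where 'd' means dense
+# (torch.Tensor) and 's' means block-sparse (stk Matrix):
+#   dsd(a, b):       sparse a @ dense b  -> dense torch.Tensor
+#   dds(a, b):       dense a  @ sparse b -> dense torch.Tensor
+#   sdd(a, b, topo): dense a  @ dense b  -> sparse Matrix with topo's sparsity pattern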
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
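+
+
+# Round-trip sketch (see matrix_ops_test.py; dense_mask comes from the sibling
+# random_ops module and is used here only for illustration):
+#
+#   mask = dense_mask(128, 256, sparsity=0.5, blocking=16)
+#   x = (torch.randn(128, 256) * mask).to(torch.float16)
+#   sp = to_sparse(x, blocking=16)
+#   assert torch.equal(to_dense(sp), x)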
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/random/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/random/random_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
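+
+
+# Usage sketch (all values illustrative):
+#
+#   m = dense_mask(256, 256, sparsity=0.75, blocking=16)   # dense 0/1 torch.Tensor
+#   sp = mask(256, 256, sparsity=0.75, blocking=16)        # stk Matrix of ones
+#   r = randn((4, 256, 256), sparsity=0.5, blocking=16)    # random sparse Matrix viewed as 3D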
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/stk/random/random_ops_test.py b/build/torch211-cxx11-cu130-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/xpu_fused_moe.py b/build/torch211-cxx11-cu130-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# default
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+        for x in arr:
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ function. Temporarily exposed here before GEMM fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
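+# Call sketch (bf16 weights, no quantization; shapes follow the docstring above,
+# and the tensor names are illustrative assumptions, not defaults of this module):
+#
+#   out = xpu_fused_moe(
+#       hidden_states=x,                          # [num_rows, hidden_size], on "xpu"
+#       w13=w13, w13_scales=None, w13_bias=None,  # [E, 2*inter_size, hidden_size]
+#       w2=w2, w2_scales=None, w2_bias=None,      # [E, hidden_size, inter_size]
+#       topk_weights=weights.float(), topk_ids=ids,
+#       n_experts_per_token=top_k, activation="silu", num_experts=E)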
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/__init__.py b/build/torch211-cxx11-xpu20253-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
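+
+
+# Convenience usage sketch (illustrative; tensors must live on the device these
+# kernels target):
+#
+#   x = torch.randint(0, 8, (1024,), dtype=torch.int32, device=device)
+#   inc = cumsum(x, dim=0)                  # inclusive cumulative sum
+#   exc = cumsum(x, dim=0, exclusive=True)  # exclusive cumulative sum
+#   vals, idxs = argsort(x, end_bit=3)      # radix sort keyed on the low 3 bits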
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/__init__.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/activation_fn.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/all_to_all.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/arguments.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+        if self.__getattribute__('mlp_impl') == 'sparse':
+            from packaging import version
+            try:
+                import triton
+                if version.parse(triton.__version__) >= version.parse('3.2.0'):
+                    raise ValueError(
+                        'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+                    )
+            except ImportError:
+                raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
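+
+
+# Configuration sketch (illustrative only; the values are arbitrary). Note that
+# the default `device` factory calls torch.cuda.current_device(), so constructing
+# Arguments assumes a CUDA-capable environment:
+#
+#     args = Arguments(
+#         hidden_size=1024,
+#         ffn_hidden_size=4096,
+#         moe_num_experts=8,
+#         moe_top_k=2,
+#         mlp_impl='grouped',  # 'sparse' additionally requires triton < 3.2.0
+#         fp16=False,
+#         bf16=True,
+#     )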
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/common.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
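+
+
+# Behaviour sketch (illustrative only): inside an autocast region the input is
+# recast to the active autocast dtype before the tokens are permuted, so the
+# bandwidth-heavy routing runs in the lower precision:
+#
+#     with torch.autocast('cuda', dtype=torch.bfloat16):
+#         x = cast_if_autocast_enabled(x)  # bfloat16 copy of a CUDA tensor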
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/dmlp_registry.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
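+
+
+# Selection sketch (illustrative only): the registry is keyed by mlp_type
+# ('mlp' or 'glu') and then by mlp_impl ('grouped' or 'sparse'), e.g.
+#
+#     args = Arguments(mlp_type='glu', mlp_impl='grouped')
+#     expert_mlp = get(args)  # -> glu.GroupedGLU(args)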
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/dmoe.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+        # There is a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
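+
+
+# Forward-path sketch (illustrative only): the dropless MoE routes tokens, pads
+# per-expert token counts up to the 128-wide block size, runs the experts as one
+# block-sparse (or grouped) matmul, and un-routes the results:
+#
+#     layer = dMoE(args).cuda()
+#     out = layer(x)  # x: [seq_len, batch, hidden]; out is a tensor, or
+#                     # (tensor, bias) when args.bias and args.return_bias are set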
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/gelu.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
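+
+
+# Reference for the hand-written backward above (illustrative): with the tanh
+# approximation gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))),
+# the constants 0.79788456 ~ sqrt(2/pi) and 0.1070322243 ~ 3 * 0.044715 * sqrt(2/pi)
+# come from d gelu(x)/dx; _gelu_backward_inplace multiplies the incoming gradient
+# by that derivative in place to avoid allocating another block-sparse tensor.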
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/glu.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+        w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)
+        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Activation and GLU gating.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
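+
+
+# Computation sketch (illustrative only): per expert e the GLU computes
+#
+#     y = (act(x @ w1[e].T) * (x @ v1[e].T)) @ w2[e]
+#
+# SparseGLU expresses this with block-sparse sdd/dsd kernels over the routing
+# topology, while GroupedGLU batches all experts in grouped GEMMs keyed by
+# `batch_sizes` (per-expert token counts, int64 on CPU).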
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/memory_test.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
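+
+
+# Launch sketch (illustrative only): this benchmark expects one process per GPU
+# so that dist.init_process_group(backend='nccl') succeeds, e.g. something like
+#
+#     torchrun --nproc-per-node=8 memory_test.py
+#
+# (the exact invocation depends on how this build directory is installed).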
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/mlp.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+        # Activation function.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+    Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
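+
+
+# Computation sketch (illustrative only): SparseMLP evaluates the experts as
+#
+#     h   = sdd(x, w1.T, topo)   # dense x (sparse topology) -> block-sparse h
+#     out = dsd(act(h), w2)      # block-sparse x dense      -> dense out
+#
+# The memory-optimized variants recompute act(h) during backward instead of
+# saving it, trading a little compute for a smaller activation footprint.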
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/moe.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+        tokens_per_expert: torch.Tensor,  # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+            # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to set up for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/mpu.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
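+
+
+# Illustrative sharding example (not part of the original module): with a
+# 16-way expert-parallel group, moe_num_experts=8 and an even
+# ffn_hidden_size, expert_sharding_degree() is 8 and hidden_sharding_degree()
+# is 2, so each rank owns experts_per_rank() == 1 expert and
+# features_per_rank() == ffn_hidden_size // 2 of its hidden features.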
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/router.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
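+
+
+# Shape note (illustrative): each saved logits tensor is [tokens, num_experts],
+# so the value returned above stacks one scalar per router layer, namely
+# moe_zloss_weight * mean over tokens of logsumexp(logits, dim=1) ** 2.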
+
+
+# NOTE: To enable end-to-end benchmarking without convergence, we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
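+
+
+# Shape walkthrough (illustrative, not executed): for an input x of shape
+# [sl, bs, hidden_size], logits and scores are [sl * bs, moe_num_experts]
+# and expert_weights / expert_indices are [sl * bs, moe_top_k]; the
+# keepdim=True in _top_k preserves the trailing dimension when moe_top_k == 1.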
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_megablocks_xpu_7a6bcf4.abi3.so b/build/torch211-cxx11-xpu20253-x86_64-linux/_megablocks_xpu_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..ee486c8abce8a03e1a612e789037d4c3a4793807
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_megablocks_xpu_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a732142f2d8813f0cbfc6fd912e421b57707789a7a35b8063b141b52182dfc5
+size 5381792
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_ops.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c8dd6eeccd632df5e23111e5dd5221d3e1fcb47
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_xpu_7a6bcf4
+ops = torch.ops._megablocks_xpu_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_xpu_7a6bcf4::{op_name}"
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/_version.py b/build/torch211-cxx11-xpu20253-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/backend/__init__.py b/build/torch211-cxx11-xpu20253-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/backend/kernels.py b/build/torch211-cxx11-xpu20253-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have
+# CUDA. This approach preserves the original code but enables testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
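+
+
+# Illustrative example (assumed values): with top_k=1, three experts and
+# per-expert token counts [2, 3, 1], bins = [2, 5, 6]. If each expert's
+# slot count is padded up to a multiple of 128, padded_bins = [128, 256, 384]
+# and the gathered output has padded_bins[-1] = 384 rows, with the unused
+# rows left as zeros.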
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding, so the number of output rows equals the
+ # number of input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
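+
+
+# Illustrative example (assumed values): for x of shape [tokens, hidden_size],
+# 4 experts and expert_capacity=64, the result has shape [4, 64, hidden_size];
+# capacity slots beyond an expert's token count stay zero because the kernel
+# returns early for out-of-range entries.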
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/benchmark_util.py b/build/torch211-cxx11-xpu20253-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
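+
+
+# Example usage (illustrative; requires a CUDA device):
+#
+#   a = torch.randn(1024, 1024, device='cuda')
+#   b = torch.randn(1024, 1024, device='cuda')
+#   mean_ms, std_ms = benchmark_function(lambda: torch.mm(a, b))
+#   log_benchmark('MatMul', {'m': 1024, 'n': 1024, 'k': 1024}, mean_ms, std_ms)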
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/cpu_fused_moe.py b/build/torch211-cxx11-xpu20253-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
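+
+
+# Example usage (illustrative, assumed sizes):
+#
+#   x = torch.randn(4, 16)         # 4 tokens, hidden_size=16
+#   router_w = torch.randn(8, 16)  # 8 experts
+#   logits, weights, ids = route_tokens_cpu(x, router_w, None, moe_top_k=2, moe_num_experts=8)
+#   # logits: [4, 8]; weights and ids: [4, 2], with weights softmax-normalized over the top-k.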
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation loops over experts, but all tokens routed to a given
+ expert are processed with batched tensor operations, which keeps the
+ Python overhead small on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Process experts one at a time: for each expert, find the (token_idx,
+ # topk_pos) pairs routed to it, compute its MLP, and accumulate the output.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
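+
+
+# Example usage (illustrative, assumed sizes, non-interleaved weights):
+#
+#   tokens, hidden, inter, experts, topk = 4, 16, 32, 8, 2
+#   h = torch.randn(tokens, hidden)
+#   w1 = torch.randn(experts, hidden, 2 * inter)
+#   w2 = torch.randn(experts, inter, hidden)
+#   _, tw, ti = route_tokens_cpu(h, torch.randn(experts, hidden), None, topk, experts)
+#   out = cpu_fused_moe(h, w1, w2, tw, ti, is_interleaved=False)  # -> [4, 16]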
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/cpu_moe_cpp.py b/build/torch211-cxx11-xpu20253-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
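+
+
+# Illustrative call (assumes the compiled megablocks extension is importable
+# and h, w1, w2, b1, b2 are placeholder tensors with the shapes listed in the
+# docstring, i.e. h: [M, K], w1: [E, 2N, K], w2: [E, K, N]):
+#
+#   out = fused_moe_cpp(h, w1, w2, topk_weights, topk_ids,
+#                       w1_bias=b1, w2_bias=b2, alpha=1.702, limit=7.0)
+#
+# dtype conversion and DTensor unwrapping are handled inside the wrapper.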
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3, or uint8 (mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm/__init__.py b/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
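+
+# Shape sketch (illustrative): with trans_a=False and trans_b=False, 'a' is
+# [sum(batch_sizes), k], 'b' is [num_groups, k, n] and the result is
+# [sum(batch_sizes), n]. With trans_a=True, 'b' must be 2-D and the result
+# collects one [a.shape[1], b.shape[1]] product per group, i.e. a tensor of
+# shape [num_groups, a.shape[1], b.shape[1]].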
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm/ops.py b/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
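+
+
+# Example usage (illustrative, assumed sizes): three experts receiving 2, 3
+# and 1 tokens, hidden_size=8, ffn_hidden_size=16:
+#
+#   a = torch.randn(6, 8, device='cuda', requires_grad=True)
+#   b = torch.randn(3, 8, 16, device='cuda', requires_grad=True)
+#   batch_sizes = torch.tensor([2, 3, 1])
+#   out = gmm(a, b, batch_sizes)  # -> [6, 16]; gradients flow via GroupedGemm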
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm_util.py b/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm
+ pass
+ _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/layers.py b/build/torch211-cxx11-xpu20253-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree from world size and number of experts
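+# Example with hypothetical values: world_size=4 and moe_num_experts=128 give an
+# expert sharding degree of 4, i.e. 128 // 4 = 32 experts per rank.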
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
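+# Example with hypothetical values: world_size=256 and moe_num_experts=128 give
+# expert_sharding_degree=128 and hidden_sharding_degree=2, so ffn_hidden_size must
+# be divisible by 2.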
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
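+# Returns (logits, expert_weights, expert_indices); expert_weights and
+# expert_indices have shape [num_tokens, moe_top_k].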
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
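+    # gate/up channels are interleaved along the last dimension (even/odd indices);
+    # with alpha ~= 1.702, gate * sigmoid(alpha * gate) closely approximates GELU.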
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
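+# With shared_expert_weighted_sum=True and moe_top_k=4 (hypothetical values), the
+# shared expert contributes 1/5 of the output and the routed experts 4/5; otherwise
+# the two outputs are simply added.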
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
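+# Auxiliary load-balancing loss: num_experts / (tokens * top_k) times the dot
+# product of per-expert token counts and mean router scores. If the scores sum to
+# one per token, perfectly uniform routing gives a value of 1.0.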
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
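+# Sort token-to-expert assignments so tokens routed to the same expert are
+# contiguous. Returns (indices, bin_ids, bins, tokens_per_expert): `indices` is the
+# permutation that groups tokens by expert, `bin_ids` holds the expert id of each
+# sorted token, and `bins` is the inclusive cumsum of tokens_per_expert marking
+# each expert's segment boundary.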
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
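+# Example with hypothetical values: tokens=1024, top_k=4, num_experts=128,
+# world_size=1, moe_capacity_factor=1.0 -> int(1.0 * 4 * 1024 / 128) = 32 slots
+# per expert.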
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
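+# Usage sketch (hypothetical sizes) for wiring up the shared expert:
+#
+#   mlp = MegaBlocksMoeMLPWithSharedExpert()
+#   up_w, down_w, up_b, down_b = create_shared_expert_weights(
+#       hidden_size=1152,
+#       shared_expert_hidden_size=3072,
+#       device=torch.device("cuda"),
+#       dtype=torch.bfloat16,
+#       init_method=torch.nn.init.kaiming_uniform_,
+#   )
+#   mlp.set_shared_expert_weights(up_w, down_w, up_b, down_b)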
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/megablocks/__init__.py b/build/torch211-cxx11-xpu20253-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/metadata.json b/build/torch211-cxx11-xpu20253-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..b911d0a2549a35a1c65ab7e77d32e5aac23cd6ac
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/metadata.json
@@ -0,0 +1,8 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "xpu"
+ }
+}
\ No newline at end of file
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/__init__.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/binned_gather.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/binned_scatter.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/cumsum.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/gather.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/histogram.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/histogram_benchmark.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/matmul_benchmark.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/padded_gather.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
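+
+# Usage sketch (shapes are assumptions): padded_gather reorders token rows
+# into contiguous, expert-sorted groups, padding each expert's group out to
+# the boundaries described by `padded_bins`; the backward pass reverses the
+# permutation with padded_scatter.
+#
+#   x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.half)
+#   out = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)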
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/padded_scatter.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/permute_benchmark.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/repeat.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
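+
+# Example: `repeat` is a thin wrapper over torch.Tensor.repeat that skips the
+# copy when the tiling is all ones.
+#   repeat(torch.ones(2, 3), torch.Size((1, 1))).shape  # (2, 3), returns x as-is
+#   repeat(torch.ones(2, 3), torch.Size((2, 1))).shape  # (4, 3)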
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/replicate.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
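+
+# Usage sketch (shapes are assumptions): given per-bin values `x` of shape
+# (rows, num_bins) and `bins` holding the inclusive cumulative sum of bin
+# sizes, replicate expands each value across its bin, producing an output of
+# shape (rows, num_outputs); num_outputs is typically bins[-1]. The backward
+# pass sums the gradient back over each bin.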
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/round_up.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
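+
+# Example: rounding per-expert token counts up to a multiple of the block size.
+#   x = torch.tensor([3, 130, 256], dtype=torch.int32)
+#   round_up(x, 128)  # tensor([128, 256, 256], dtype=torch.int32)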
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/scatter.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/sort.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
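+
+# Usage sketch: radix-sorts integer keys and also returns the permutation that
+# produced the sorted order.
+#   top_expert = torch.randint(0, 8, (1024,), device="cuda").int()
+#   sorted_ids, indices = sort(top_expert, 3)  # 3 bits suffice for values < 8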
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/sort_benchmark.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/stk_autocast.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/sum.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
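+
+# Example: squeezes instead of reducing when the dimension is already 1, which
+# avoids launching a reduction kernel.
+#   sum(torch.ones(1, 4), dim=0).shape  # (4,)
+#   sum(torch.ones(3, 4), dim=0).shape  # (4,)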
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/ops/topology.py b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/__init__.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/__init__.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/autocast.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/sputnik.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/triton_kernels.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: got a dimension of length {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/matrix.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers.
+# 3. Make indentation consistent.
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D data.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+            f"Got data on {data.device}, row_indices on {row_indices.device}, "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/__init__.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
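+
+# Usage sketch (assumes b was built with a's topology, e.g.
+# Matrix(a.size(), new_data, a.row_indices, a.column_indices, a.offsets)):
+#   c = mul(a, b)  # c.data == a.data * b.data, same sparsity pattern as a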
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
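+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only, not part of the original module):
+    # requires a CUDA device and the compiled sputnik backend, following the
+    # conventions of linear_ops_test.py.
+    import torch
+    import stk
+
+    blocking, dtype = 128, torch.float16
+    mask = stk.random.dense_mask(512, 512, 0.5, blocking)
+    a = stk.ops.to_sparse((torch.randn(512, 512) * mask).type(dtype), blocking)
+    a = a.to(torch.device("cuda"))
+    b = torch.randn(512, 256, dtype=dtype, device="cuda")
+
+    # Sparse x dense -> dense.
+    out = stk.ops.dsd(a, b)
+
+    # Dense x dense -> sparse, with the output topology taken from `topo`.
+    topo = stk.random.mask(512, 256, 0.5, blocking).to(torch.device("cuda"))
+    c = stk.ops.sdd(torch.randn(512, 512, dtype=dtype, device="cuda"), b, topo)
+    print(out.shape, stk.ops.to_dense(c).shape)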
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# can be done much more simply than it is currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
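+
+
+if __name__ == "__main__":
+    # Minimal round-trip sketch (illustrative only, not part of the original
+    # module), following the checks in matrix_ops_test.py. Runs on CPU.
+    import torch
+    import stk
+
+    mask = stk.random.dense_mask(128, 256, 0.5, blocking=16)
+    x = (torch.randn(128, 256) * mask).type(torch.float16)
+    sparse_x = stk.ops.to_sparse(x, blocking=16)
+    dense_x = stk.ops.to_dense(sparse_x)
+    assert torch.all(torch.eq(x, dense_x))
+    print(sparse_x.nnz, torch.count_nonzero(dense_x).item())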
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/random/__init__.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/random/random_ops.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
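+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only, not part of the original module).
+    # dense_mask() returns a dense 0/1 tensor; mask() and randn() return
+    # stk.Matrix objects built on top of it.
+    d = dense_mask(128, 256, sparsity=0.75, blocking=16)
+    print(d.shape, torch.count_nonzero(d).item())
+
+    m = mask(128, 256, sparsity=0.75, blocking=16)
+    r = randn((128, 256), sparsity=0.75, blocking=16)
+    print(m.size(), r.size())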
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/stk/random/random_ops_test.py b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch211-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py b/build/torch211-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch211-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# default
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
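+
+# Worked example (illustrative): for num_tokens=4096 and num_experts_per_node=8
+# the loop returns 256, the smallest candidate block size for which
+# ceilDiv(4096, block) * 8 <= block (ceilDiv(4096, 256) * 8 = 128 <= 256).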
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
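+
+# Example (illustrative): an 8-byte uint8 workspace slice reinterpreted as a
+# single int64 value, e.g.
+#   buf = torch.zeros(8, dtype=torch.uint8)
+#   val = _bytes_to_typed_tensor(buf, torch.int64)   # shape (1,) int64 tensor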
+
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
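+
+# Worked example (illustrative): the packed byte 0x01 holds the u4 values 0
+# (high nibble) and 1 (low nibble). After subtracting the default zero point 8
+# they become -8 and -7; re-packed as (sign << 3) | (low three bits of the
+# two's-complement byte) this yields ((1 << 3) | 0) << 4 | ((1 << 3) | 1) = 0x89.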
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+    # 4-bit weights use an [E, N, K] layout;
+    # other dtypes use [E, K, N].
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ function; temporarily exposed here before GEMM fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
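+
+# Example (illustrative): route 16 tokens of width 64 across 8 experts, top-2.
+#   x = torch.randn(16, 64)
+#   w = torch.randn(8, 64)                      # router weight [num_experts, hidden]
+#   logits, weights, ids = route_tokens_xpu(x, w, None, moe_top_k=2,
+#                                           moe_num_experts=8)
+#   # logits: [16, 8]; weights: [16, 2] softmaxed top-k scores; ids: [16, 2]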
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
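+
+# Example (illustrative; requires the compiled CUDA ops on a CUDA device):
+#   x = torch.randint(0, 8, (1024,), dtype=torch.int32, device="cuda")
+#   counts = histogram(x, num_bins=8)            # tokens per bin
+#   offsets = cumsum(counts, exclusive=True)     # start offset of each bin
+#   sorted_x, order = argsort(x, end_bit=3)      # 3 bits cover values 0..7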
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
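+
+
+# Example (illustrative): exchange rows between ranks of an initialized
+# torch.distributed process group. Split sizes are per-rank row counts; with
+# async_op=True the returned handle must be waited on before reading `out`.
+#   out, handle = all_to_all(x, output_split_sizes=[2, 3],
+#                            input_split_sizes=[4, 1],
+#                            group=dist.group.WORLD, async_op=True)
+#   handle.wait()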
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (allows custom FC layers, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
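+
+
+# Example (illustrative): a small dropless-MoE configuration. Note that the
+# default `device` factory calls torch.cuda.current_device(), so either a CUDA
+# device must be visible or `device` must be overridden; mlp_impl='grouped'
+# additionally requires the grouped_gemm backend to be installed.
+#   args = Arguments(hidden_size=1024, ffn_hidden_size=4096,
+#                    moe_num_experts=8, moe_top_k=2,
+#                    bf16=True, fp16=False, mlp_impl='grouped',
+#                    device=torch.device('cuda'))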
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
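+
+
+# Example (illustrative):
+#   from .arguments import Arguments
+#   mlp = get(Arguments(mlp_type='glu', mlp_impl='sparse'))    # -> SparseGLU
+#   mlp = get(Arguments(mlp_type='mlp', mlp_impl='grouped'))   # -> GroupedMLP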
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+        # There is a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
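+        # For example, with padded_tokens=512, ffn_hidden_size=512 and
+        # blocking=128 there are 4 block rows of 4 blocks each, giving
+        # offsets = [0, 4, 8, 12, 16].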
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
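+        # Example with blocking=128: tokens_per_expert=[3, 200, 0] rounds up to
+        # [128, 256, 0], giving padded_bins=[128, 384, 384].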
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
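+    # In-place derivative of the tanh-approximate GELU,
+    # 0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x**3))),
+    # where 0.79788456 ~= sqrt(2/pi) and 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi).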
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Activation function and GLU gating.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
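+        # Each weight is now [experts_per_rank, features_per_rank, hidden_size],
+        # the per-expert layout consumed by the grouped GEMMs below.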
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
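+# Each test case is (batch_size, sequence_length, hidden_size, ffn_hidden_size,
+# num_experts, top_k), unpacked into test_memory below.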
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
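+    # Collect all live, contiguous tensors tracked by the garbage collector,
+    # de-duplicating aliases that share a data pointer.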
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
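+# Identity in the forward pass; the backward pass multiplies the incoming
+# gradient by `scale`. Used to rescale expert gradients under expert model
+# parallelism.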
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
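+        # x arrives as [experts, expert_capacity, hidden_size] from binned_gather;
+        # both layers run as batched matmuls over the expert dimension.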
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+            # Enable using a weighted sum for the shared expert output,
+            # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
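+            # Output = shared_out / (top_k + 1) + expert_out * top_k / (top_k + 1),
+            # i.e. the shared expert counts as one more expert in the average.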
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
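+        # Example: with moe_top_k=2, 4096 local tokens, 8 expert-parallel ranks
+        # and 64 experts, this is moe_capacity_factor * 1024 tokens per expert.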
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
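+    # Example: 16 expert-parallel ranks and 8 experts give esd=8 and hsd=2, so
+    # two ranks share each expert and split its ffn_hidden_size features.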
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
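+    # The z-loss per router is the mean over tokens of logsumexp(logits)^2,
+    # which penalizes large router logits (as in ST-MoE).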
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operation.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
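+# e.g. for an index tensor with 6 elements and num_experts=4 the assignment
+# above is [0, 1, 2, 3, 0, 1] (reshaped to the input's shape): tokens are
+# spread round-robin across experts regardless of the learned router scores.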
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_885c7a2.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_885c7a2.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..0d80074bae15f25f2ac4a90a2f5511cb5d01309c
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_885c7a2.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ee1601097f38f9ba908bad9f2844b50f1ffdd52379ec9548c7873b34ee00271
+size 10509584
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec290dd41dd30ed4551035db04f6c85ee1a0fe0
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_885c7a2
+ops = torch.ops._megablocks_885c7a2
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_885c7a2::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub triton.autotune when testing in an environment that does not have CUDA.
+# This preserves the original code while enabling testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
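+    # e.g. with 3 tokens, top_k=1, bins=[2, 3] and padded_bins=[4, 8], the
+    # output has padded_bins[-1] = 8 rows: rows 0-1 hold expert 0's tokens,
+    # row 4 holds expert 1's token, and the never-written rows stay zero.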
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per token assignment. Array 'x' has a greater or
+    # equal number of rows since it could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
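+    # Tokens beyond an expert's capacity are dropped: the (num_experts,
+    # expert_capacity) launch grid below copies at most expert_capacity rows
+    # per bin, e.g. an expert assigned 10 tokens with expert_capacity=8 only
+    # has its first 8 tokens gathered; the rest contribute zeros on scatter.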
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
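+
+
+# Typical usage (illustrative):
+#
+#   mean_ms, std_ms = benchmark_function(lambda: layer(x))
+#   log_benchmark('MoE', {'batch_size': 8}, mean_ms, std_ms)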
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package because
+# grouped_gemm is vendored inside megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
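+
+
+# Shape sketch for gmm above: 'a' is (tokens, k), 'b' is (num_experts, k, n)
+# (or (num_experts, n, k) when trans_b=True) and 'batch_sizes' is a 1-D
+# tensor of per-expert token counts summing to 'tokens'; the result is
+# (tokens, n), i.e. one independent matmul per expert group.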
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored inside this package, so no external import is
+    # needed here; the availability flag is simply flipped to True.
+    _grouped_gemm_is_available = True
+except ImportError:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+    )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b835ac5f6929edb8b547f373212388f34be3868
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py
@@ -0,0 +1,1225 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree from the world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
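+    # e.g. for x of shape (seq_len, batch, hidden) the router sees
+    # seq_len * batch flattened tokens: logits is (tokens, moe_num_experts),
+    # expert_weights and expert_indices are (tokens, moe_top_k).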
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
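+    # w1 packs the gate and up projections interleaved along the last dim:
+    # even columns feed the gate, odd columns feed the up projection, e.g. a
+    # (num_experts, hidden, 2 * ffn) w1 splits into two ffn-wide halves below.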
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
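+        # e.g. with moe_top_k=3: shared_weight = 0.25 and expert_weight = 0.75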
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
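+    # e.g. tokens=1024, top_k=4, num_experts=128, world_size=1 and
+    # moe_capacity_factor=1.0 give int(4 * 1024 / 128) = 32 slots per expert.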
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
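+
+
+# Illustrative usage sketch (assumptions only, not part of the kernel API):
+# the shapes, dtype and init functions below are made up for the example. It
+# shows how the shared-expert variant is meant to be wired together with
+# create_shared_expert_weights() and set_shared_expert_weights().
+#
+#   mlp = MegaBlocksMoeMLPWithSharedExpert()
+#   # router and experts configured as for MegaBlocksMoeMLP
+#   up_w, down_w, up_b, down_b = create_shared_expert_weights(
+#       hidden_size=1024,
+#       shared_expert_hidden_size=4096,
+#       device=torch.device("cuda"),
+#       dtype=torch.bfloat16,
+#       init_method=torch.nn.init.xavier_uniform_,
+#   )
+#   mlp.set_shared_expert_weights(
+#       up_w, down_w, up_b, down_b,
+#       weighted_sum=True,
+#       activation_fn=torch.nn.functional.gelu,
+#   )
+#   out, expert_weights = mlp(x)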
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2-byte (fp16) elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
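+
+
+# Note (descriptive, based on how the op is used elsewhere in this package):
+# binned_gather(x, indices, bins, bin_size, top_k) copies the tokens assigned
+# to each expert into a dense (num_experts, bin_size, hidden_size) tensor,
+# where bin_size is the fixed per-expert capacity. binned_scatter is the
+# inverse permutation, which is why it appears in the backward pass above.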
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
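+
+
+# Illustrative example (values assumed): with tokens_per_expert = [3, 1, 2]
+# (torch.int32, on CUDA),
+#   exclusive_cumsum(tokens_per_expert, 0) -> [0, 3, 4]
+#   inclusive_cumsum(tokens_per_expert, 0) -> [3, 4, 6]
+# The inclusive form is what the routing code uses to build the per-expert
+# "bins" (end offsets of each expert's token range).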
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
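+
+
+# Illustrative example (values assumed): for expert assignments
+# top_expert = [2, 0, 1, 0] and max_val = 4 experts,
+#   histogram(top_expert, 4) -> [2, 1, 1, 0]
+# i.e. the number of tokens routed to each expert.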
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# stk is vendored under megablocks.stk, so use the local copy rather than the
+# external `stanford-stk` package.
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+        # Materialize the transposed layout, then take a transposed view so the
+        # batched matmul sees a non-contiguous NT operand.
+        w = w.transpose(1, 2).contiguous()
+        w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd:DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradX:DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradW:DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+        def benchmark():
+            return ops.binned_gather(x, indices, bins, ec, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.binned_gather(x, indices, bins, ec, 1)
+
+        def benchmark():
+            return ops.binned_scatter(x, indices, None, bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+        def benchmark():
+            return ops.padded_scatter(x, indices, bin_ids, None, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
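+
+
+# Illustrative example: round_up(torch.tensor([5, 128, 130], dtype=torch.int32), 128)
+# returns [128, 128, 256], i.e. each entry rounded up to the next multiple of `value`.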
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
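+
+
+# Illustrative example (values assumed): for expert assignments
+# x = torch.tensor([2, 0, 1, 0], dtype=torch.int32, device="cuda"),
+#   sort(x) -> ([0, 0, 1, 2], [1, 3, 2, 0])
+# The first output holds the expert ids in sorted order and the second holds
+# the original token positions, which is how the routing code obtains
+# `bin_ids, indices = ops.sort(top_expert)`.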
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
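+# Sparse = dense @ dense: one program per nonzero output block. The block's
+# (row, col) coordinates come from row_indices/column_indices, and the dense
+# operands are reduced along K to fill that block of the sparse output.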
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
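+# Dense = sparse @ dense: each program owns one BLOCK_M x BLOCK_N output tile
+# and walks the nonzero blocks of its block row (offsets[pid_m] to
+# offsets[pid_m + 1]), gathering the matching rows of the dense operand via
+# column_indices.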
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
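+# Dense = dense @ sparse: mirrors _dsd_kernel with the operands swapped. Each
+# program walks the nonzero blocks of its output block column using the
+# (transposed) offsets and column indices of the sparse operand.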
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
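+# Expand block-row offsets into one row id per nonzero block: program `pid`
+# writes its row index into `out` for every block in [offsets[pid],
+# offsets[pid + 1]).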
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
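+# Build the metadata needed to traverse the matrix by block column: the
+# transposed column indices and offsets, plus a permutation mapping the
+# transposed block order back into the storage order of `data`.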
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
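+
+    # Storage sketch (BCSR): `data` holds the nonzero blocks as
+    # [nnz_blocks, block, block]; `row_indices` and `column_indices` give the
+    # block coordinates of each nonzero block; `offsets` marks where each block
+    # row starts in that list. The *_t fields cache the same metadata for the
+    # transpose.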
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
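+
+# Usage sketch (illustrative; `a` and `b` are hypothetical Matrix objects built
+# over the same topology, e.g. b = ones_like(a)): mul(a, b) reuses a's metadata
+# and multiplies only the stored blocks.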
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
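+# Naming convention: each letter gives the format of (output, lhs, rhs), where
+# "d" is dense and "s" is block-sparse; e.g. dsd computes a dense output from a
+# sparse lhs and a dense rhs.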
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
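+    # Expand per-block (row, col) index pairs into element-level (row, col)
+    # pairs covering every entry of each blocking x blocking tile.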
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
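+# Sample a {0, 1} mask of shape (rows, cols) in which roughly `sparsity` of the
+# blocking x blocking tiles are zeroed out.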
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
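+
+
+# Illustrative sketch (not part of the original exports; `expert_ids` and
+# `num_experts` are hypothetical names, and the underlying ops are assumed to
+# expect CUDA tensors):
+#   sorted_ids, gather_idx = argsort(expert_ids)
+#   counts = histogram(expert_ids, num_bins=num_experts)
+#   bins = cumsum(counts, exclusive=False)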
+
+
+# Export public API
+__all__ = [
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
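+# Autograd-aware wrapper around dist.all_to_all_single: forward exchanges
+# variable-sized splits across the process group, and backward runs the same
+# exchange with the split sizes swapped, so gradients are routed back to the
+# ranks that produced the corresponding inputs.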
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
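+
+
+# Illustrative construction (added note; values are examples only, not
+# recommended defaults). Note that the default `device` factory calls
+# torch.cuda.current_device(), so constructing Arguments requires CUDA unless
+# an explicit device is passed.
+#
+#     args = Arguments(
+#         hidden_size=1024,
+#         ffn_hidden_size=4096,
+#         moe_num_experts=8,
+#         moe_top_k=2,
+#         mlp_impl='grouped',  # 'sparse' requires triton < 3.2.0
+#         bf16=True,
+#         fp16=False,
+#     )
+#
+# __post_init__ validates the mlp_impl choice and defaults
+# shared_expert_hidden_size to ffn_hidden_size when it is left unset.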
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
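+
+
+# Illustrative lookup (added note): with mlp_type='glu' and mlp_impl='grouped'
+# this returns glu.GroupedGLU(args); unsupported combinations raise ValueError
+# before any expert weights are allocated.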
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
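+        # Worked example (added note, illustrative only). With blocking=128 and
+        # a dense 256x256 block layout, blocks_per_row = 2 and the block ids
+        # are stored row-major:
+        #
+        #     row_indices    = [0, 0, 1, 1]
+        #     column_indices = [0, 1, 0, 1]
+        #
+        # Sorting by column index gives gather_indices = [0, 2, 1, 3], so
+        # column_indices_t = row_indices[gather_indices] = [0, 1, 0, 1],
+        # block_offsets_t = [0, 2, 1, 3], and offsets_t = [0, 2, 4] (two
+        # nonzero blocks per transposed column).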
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
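+        # Illustrative example (added note): with 4 experts, blocking=128 and
+        # tokens_per_expert = [5, 130, 0, 1], the padded counts are
+        # [128, 256, 0, 128], so padded_bins = [128, 384, 384, 512] while
+        # bins = [5, 135, 135, 136]. Each expert's rows start at the previous
+        # padded bin boundary and are rounded up to a multiple of the block
+        # size.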
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
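+
+
+# Added note: _gelu_backward_inplace implements the derivative of the
+# tanh-approximated GELU, gelu(x) = 0.5 * x * (1 + tanh(0.79788456 * (x +
+# 0.044715 * x**3))), multiplied into the incoming gradient in place. A quick
+# autograd cross-check (illustrative sketch only, on dense tensors):
+#
+#     x = torch.randn(8, requires_grad=True)
+#     F.gelu(x, approximate='tanh').sum().backward()
+#     manual = _gelu_backward_inplace(torch.ones(8), x.detach())
+#     torch.testing.assert_close(manual, x.grad, rtol=1e-4, atol=1e-4)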
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
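+        # Illustrative example (added note): with moe_top_k = 3 the weighted
+        # path above uses t_experts = 4, i.e.
+        #     out = shared_expert_out * (1 / 4) + expert_out * (3 / 4),
+        # so the shared expert counts as one extra "expert" in the average.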
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
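+    # Illustrative scale (added note, example numbers only): with
+    # moe_loss_weight=0.1, moe_num_experts=8, num_layers=2, tokens=1024 and
+    # moe_top_k=2, the scale is (8 * 0.1) / (2 * 1024 * 2) = 0.8 / 4096; the
+    # returned loss is that scale times the dot product of the concatenated
+    # per-layer token counts and mean expert scores.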
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
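+        # Illustrative example (added note): with moe_top_k=2, 4096 tokens on
+        # this rank, an expert parallel world size of 8, 64 experts and
+        # moe_capacity_factor=1, the capacity is int(1 * 2 * 4096 * 8 / 64)
+        # = 1024 tokens per expert.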
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
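+
+
+# A minimal, illustrative sketch (not part of the vendored API): the helper
+# name, the SimpleNamespace stand-in for Arguments, and the sizes below are
+# hypothetical. It shows the intended flow: logits are saved during the
+# router forward pass and later reduced into one z-loss value per router.
+def _example_batched_router_zloss():
+    from types import SimpleNamespace
+    args = SimpleNamespace(moe_zloss_weight=1e-3, moe_zloss_in_fp32=True)
+    logits = torch.randn(8, 4)  # [tokens, num_experts]
+    _save_router_logits(logits, args)
+    zloss = batched_router_zloss(args)  # one scaled value per saved router
+    clear_router_zloss()
+    return zloss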
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_885c7a2.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_885c7a2.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..78f4bd294d6978983cb2f4940fe24d66fc47c5f8
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_885c7a2.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac18b8df258ecd2e14581e6ec042b55d296f63ec1ab9d1704d5332a0d9cef05a
+size 11918752
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec290dd41dd30ed4551035db04f6c85ee1a0fe0
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_885c7a2
+ops = torch.ops._megablocks_885c7a2
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_885c7a2::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_version.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This preserves the original code while enabling testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
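+
+
+# A minimal round-trip sketch (hypothetical sizes; assumes a CUDA device with
+# Triton available). The routing metadata is rebuilt here with plain torch
+# calls purely for illustration; in the library it comes from the sort /
+# histogram / inclusive_cumsum ops.
+def _example_gather_scatter_roundtrip(tokens=8, hidden_size=16, num_experts=4, top_k=2):
+    x = torch.randn(tokens, hidden_size, device='cuda')
+    top_experts = torch.randint(0, num_experts, (tokens * top_k,), device='cuda')
+    # Sort assignments by expert; 'indices' maps sorted slots back to tokens.
+    bin_ids, indices = torch.sort(top_experts)
+    bin_ids, indices = bin_ids.int(), indices.int()
+    tokens_per_expert = torch.bincount(top_experts, minlength=num_experts)
+    bins = torch.cumsum(tokens_per_expert, dim=0).int()
+    weights = torch.rand(tokens * top_k, device='cuda')
+    y = gather(x, indices, bin_ids, None, bins, top_k)        # [tokens * top_k, hidden_size]
+    out = scatter(y, indices, bin_ids, weights, bins, top_k)  # [tokens, hidden_size]
+    return out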
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per output entry in 'wgrad'. Array 'x' has at least
+    # as many rows as 'grad' since its rows may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
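+
+
+# A minimal usage sketch (assumes a CUDA device; the matmul and its sizes are
+# purely illustrative).
+if __name__ == '__main__':
+    a = torch.randn(1024, 1024, device='cuda')
+    b = torch.randn(1024, 1024, device='cuda')
+    mean_ms, std_ms = benchmark_function(lambda: torch.matmul(a, b))
+    log_benchmark('MatMul', {'m': 1024, 'n': 1024, 'k': 1024}, mean_ms, std_ms)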
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
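+
+
+# A minimal usage sketch (hypothetical sizes; assumes a CUDA device, the
+# compiled megablocks extension providing the grouped GEMM backend, and that
+# the backend accepts bfloat16 inputs with CPU int64 batch_sizes).
+if __name__ == '__main__':
+    num_experts, k, n = 4, 64, 128
+    batch_sizes = torch.tensor([3, 5, 2, 6])  # per-expert token counts, on CPU
+    a = torch.randn(int(batch_sizes.sum()), k, device='cuda', dtype=torch.bfloat16)
+    b = torch.randn(num_experts, k, n, device='cuda', dtype=torch.bfloat16)
+    c = gmm(a, b, batch_sizes)  # [sum(batch_sizes), n]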
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm
+    _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b835ac5f6929edb8b547f373212388f34be3868
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py
@@ -0,0 +1,1225 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+                        (bins.shape[0], bin_size, hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+                        (indices.shape[0] // top_k, x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree from the world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
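+
+
+# A small worked example (hypothetical numbers, not tied to any particular
+# deployment): with world_size=8, 4 experts and ffn_hidden_size=1024, the
+# experts are sharded 4 ways and the FFN features 2 ways, so each rank owns
+# 1 expert and 512 features.
+def _example_sharding_degrees():
+    world_size, num_experts, ffn = 8, 4, 1024
+    assert expert_sharding_degree(world_size, num_experts) == 4
+    assert hidden_sharding_degree(world_size, num_experts, ffn) == 2
+    assert experts_per_rank(num_experts, world_size) == 1
+    assert features_per_rank(ffn, world_size, num_experts) == 512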
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
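+
+
+# A minimal routing sketch (hypothetical sizes; runs on CPU). The random
+# router weight is illustrative only.
+def _example_route_tokens():
+    hidden_size, num_experts, top_k = 32, 8, 2
+    x = torch.randn(4, 6, hidden_size)  # [sl, bs, hs]
+    router_weight = torch.randn(num_experts, hidden_size)
+    logits, expert_weights, expert_indices = route_tokens(
+        x, router_weight, None, top_k, num_experts
+    )
+    # logits: [sl * bs, num_experts]; expert_weights / expert_indices: [sl * bs, top_k]
+    return logits, expert_weights, expert_indices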
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
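+
+
+# A minimal shape sketch (hypothetical sizes; assumes plain, non-DTensor
+# parameters). Gate and up projections are interleaved along the last
+# dimension of w1, so it is twice the intermediate size.
+def _example_mlp_forward():
+    num_experts, capacity, hidden, intermediate = 4, 8, 32, 64
+    x = torch.randn(num_experts, capacity, hidden)
+    w1 = torch.randn(num_experts, hidden, 2 * intermediate) * 0.02
+    w1_bias = torch.zeros(num_experts, 2 * intermediate)
+    w2 = torch.randn(num_experts, intermediate, hidden) * 0.02
+    w2_bias = torch.zeros(num_experts, hidden)
+    return mlp_forward(x, w1, w2, w1_bias, w2_bias)  # [num_experts, capacity, hidden]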
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
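+
+
+# A small worked example (hypothetical numbers): 1024 tokens, top_k=4,
+# 128 experts, capacity factor 1.0 and no expert parallelism give
+# int(4 * 1024 / 128) = 32 slots per expert.
+def _example_expert_capacity():
+    return expert_capacity(
+        tokens=1024,
+        top_k=4,
+        num_experts=128,
+        expert_parallel_group=None,
+        moe_capacity_factor=1.0,
+        moe_expert_model_parallelism=False,
+    )  # -> 32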
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
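+
+
+# A small worked example (hypothetical numbers): 4 tokens routed with top_k=2
+# over 4 experts. With perfectly uniform routing and uniform router scores the
+# loss reduces to 1.0.
+def _example_load_balancing_loss():
+    num_experts, top_k, tokens = 4, 2, 4
+    tokens_per_expert = torch.tensor([2, 2, 2, 2])           # tokens * top_k assignments, spread evenly
+    expert_scores = torch.full((tokens, num_experts), 0.25)  # uniform router probabilities
+    return load_balancing_loss(tokens_per_expert, expert_scores, top_k, num_experts)  # -> 1.0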
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Kick off the count exchange asynchronously so it overlaps with the
+        # local permutation below.
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
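+        # remainder() converts each flat bin index into a local expert id and
+        # replicate() expands it to one entry per received token.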
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: Optional[torch.Tensor] = None,
+    w2: Optional[torch.Tensor] = None,
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+    mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
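+    """Route tokens, dispatch them through ``forward_fn`` (typically
+    ``forward_once`` or ``parallel_forward_once``), and return
+    ``(output, expert_weights, router_scores)``."""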
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
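+    # Dense [num_experts, num_tokens] scores: the top-k routing weights scattered
+    # back into their expert positions.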
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: Optional[torch.Tensor] = None,
+    w2: Optional[torch.Tensor] = None,
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+    mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
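+    """Run the standard MoE forward pass and, when shared-expert weights are
+    provided, combine the shared expert's output with the routed-expert output."""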
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
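+    """Allocate and initialize the shared expert's up/down projection weights.
+
+    Returns ``(up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias)``;
+    the biases are ``None`` by default.
+    """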
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
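+    """Mixture-of-experts MLP backed by the megablocks kernels.
+
+    Expects ``self.router`` (a linear layer with ``weight``/``bias`` and an
+    optional ``top_k`` attribute) and ``self.experts`` carrying ``gate_up_proj``,
+    ``gate_up_proj_bias``, ``down_proj``, ``down_proj_bias``, ``hidden_size`` and
+    optional tuning attributes such as ``num_experts`` and ``capacity_factor``.
+    ``forward`` returns ``(output, expert_weights)``.
+    """
+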
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
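+    """MegaBlocksMoeMLP with an optional shared expert applied to every token.
+
+    Shared-expert weights are attached via ``set_shared_expert_weights``; when
+    present, their output is combined with the routed-expert output.
+    """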
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per fp16 element.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
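+        # The kernel expects a 2-D input: lift 1-D tensors to a single row and
+        # squeeze the result back afterwards.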
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd:DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradX:DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradW:DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.binned_gather(x, indices, bins, ec, 1)  # top_k=1: one expert per token above.
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.binned_gather(x, indices, bins, ec)
+
+ def benchmark():
+            return ops.binned_scatter(x, indices, None, bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ def benchmark():
+            return ops.padded_scatter(x, indices, bin_ids, None, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
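+    # Fast path: a tiling of all ones is a no-op, so skip the copy x.repeat() would make.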
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of less than 1k elements.
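+    # Example: round_up(torch.tensor([3, 130], dtype=torch.int32), 128) -> [128, 256].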
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
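+# Fast path: when the reduced dimension has length 1, squeezing it yields the same
+# tensor as summing, without launching a reduction kernel.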
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
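+# A dense 2D tensor counts as "transposed" when its layout is column-major
+# (stride 1 along dim 0), i.e. it is a .t() view of a contiguous tensor.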
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
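+# Gradient helpers: for y = lhs @ rhs, d(lhs) = dy @ rhs^T and d(rhs) = lhs^T @ dy.
+# The helpers below pick the operand order and transpose flags to account for
+# lhs/rhs that are themselves stored transposed.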
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
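+# All kernels below are autotuned over this single configuration, so every matmul
+# dimension must be divisible by the corresponding block size (see
+# _validate_matmul_dims).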
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
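+# Builds the transposed matrix's metadata: sorting blocks by column index yields
+# column_indices_t (the original row indices in that order), a histogram of column
+# indices cumulatively summed gives offsets_t, and block_offsets_t records where
+# each transposed block lives in the original data tensor.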
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
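+# Naming convention: each letter gives the layout of (output, lhs, rhs).
+# dsd: dense = sparse @ dense; dds: dense = dense @ sparse;
+# sdd: sparse (with topo's topology) = dense @ dense.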
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
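+# Expands block-level (row, col) coordinates into the element-level coordinates of
+# every entry inside each `blocking` x `blocking` block.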
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
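+# Builds a dense {0, 1} mask at block granularity: round(block_rows * block_cols *
+# (1 - sparsity)) blocks are kept, the rest are zeroed, and the block mask is then
+# tiled up to element resolution as a float32 tensor.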
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
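+
+# A minimal usage sketch for the convenience wrapper above, kept as a comment so
+# nothing executes at import time. It assumes a CUDA device and a 1-D integer
+# input, which is the shape/dtype the routing code feeds these kernels:
+#
+#   counts = torch.randint(0, 4, (8,), dtype=torch.int32, device="cuda")
+#   inc = cumsum(counts, dim=0)                   # inclusive running totals
+#   exc = cumsum(counts, dim=0, exclusive=True)   # totals shifted by one position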
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
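+
+# Hedged example of the wrapper above (comment only). Note that `iota_out` is
+# allocated with the same dtype as `x`, so the returned indices share the input's
+# dtype; the expert ids sorted in this codebase are small integers, so that is fine:
+#
+#   ids = torch.tensor([3, 1, 2, 1], dtype=torch.int32, device="cuda")
+#   sorted_ids, order = argsort(ids, end_bit=2)   # 2 bits are enough for ids 0..3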
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
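+
+
+# Rough usage sketch (comment only). With an initialized process group, each rank
+# sends `input_split_sizes[i]` rows of `x` to rank i and receives
+# `output_split_sizes[i]` rows from it; the handle is only meaningful when
+# async_op=True. The group below is an assumption, not something this module sets up:
+#
+#   out, handle = all_to_all(x, output_split_sizes, input_split_sizes,
+#                            group=dist.group.WORLD, async_op=True)
+#   handle.wait()  # overlap independent work before waiting, if desired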
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/arguments.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
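+
+
+# Illustrative construction (comment only); the values are arbitrary and 'grouped'
+# is chosen because the sparse path requires triton < 3.2.0 (see __post_init__):
+#
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_impl='grouped',
+#       fp16=False,
+#       bf16=True,
+#   )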
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/common.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
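+
+
+# For example, Arguments(mlp_type='glu', mlp_impl='grouped') resolves to
+# glu.GroupedGLU, while the defaults ('mlp', 'sparse') resolve to mlp.SparseMLP.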
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
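+
+        # Worked example (illustrative only): with blocking = 128 and
+        # tokens_per_expert = [3, 1, 0, 2], the padded counts are [128, 128, 0, 128],
+        # so padded_bins = [128, 256, 256, 384] while bins = [3, 4, 4, 6].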
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
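+
+
+# Sanity-check sketch (comment only, assumes a CUDA tensor): the closed-form
+# derivative in _gelu_backward_inplace should agree with autograd through the
+# tanh-approximate gelu. This is an illustration, not part of the module:
+#
+#   x = torch.randn(64, device="cuda", requires_grad=True)
+#   F.gelu(x, approximate='tanh').sum().backward()
+#   manual = _gelu_backward_inplace(torch.ones(64, device="cuda"), x.detach())
+#   assert torch.allclose(x.grad, manual, atol=1e-5)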
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+    Note: this is a copy -> paste -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+    Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+ # wieghted by number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/moe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
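+        # e.g. with capacity_factor = 1, top_k = 2, 1024 tokens, world_size = 1 and
+        # 8 experts this sizes each expert for int(1 * 2 * 1024 * 1 / 8) = 256 tokens
+        # (illustrative numbers only).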
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
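+ # As a small illustration: with num_experts=3 and top_expert=[2, 0, 2, 1],
+ # the sort yields bin_ids=[0, 1, 2, 2] and indices=[1, 3, 0, 2], the
+ # histogram gives tokens_per_expert=[1, 1, 2], and the inclusive cumsum
+ # gives bins=[1, 2, 4], i.e. the end offset of each expert's bin.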
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+ # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
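+ # For instance, with two expert-parallel ranks and four experts (two per
+ # rank), rank 0 ships every token routed to experts 2-3 to rank 1 and
+ # receives rank 1's tokens for experts 0-1; the send/recv counts computed
+ # below are the per-rank row counts for exactly this exchange.
+ #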
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension,
+ # multiple devices own parts of the same set of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension,
+ # multiple devices own parts of the same set of experts.
+ # Replicate the tokens so devices that share experts get all
+ # of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mpu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
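+
+# As a worked example: with an expert-parallel world size of 4 and
+# moe_num_experts=2, expert_sharding_degree=2 and hidden_sharding_degree=2,
+# so each rank owns a single expert, ffn_hidden_size must be divisible by 2,
+# and features_per_rank = ffn_hidden_size // 2.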
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/router.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
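+
+
+# Shape summary: for an input x of shape [sl, bs, hs] the router produces
+# scores of shape [sl * bs, moe_num_experts] and expert_weights /
+# expert_indices of shape [sl * bs, moe_top_k], which ParallelMLP flattens
+# before computing the token permutation.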
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_885c7a2.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_885c7a2.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..d6127bc51af7aca7db0bc5948113eebd538fc945
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_885c7a2.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ad7baca6130dfa611dfe75ee1691d31d651cf8a49725c6a1caa5bb0ed22ba48
+size 17876184
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec290dd41dd30ed4551035db04f6c85ee1a0fe0
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_885c7a2
+ops = torch.ops._megablocks_885c7a2
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_885c7a2::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_version.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code while enabling testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since its rows may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
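+ # For instance, with bins=[3, 5] and padded_bins=[4, 8], expert 0's three
+ # tokens occupy rows 0-2 of the 8-row output and expert 1's two tokens
+ # start at row 4, leaving the padding rows zero-initialized.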
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since its rows may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
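+ # Shape sketch: x of shape (tokens, hidden_size) is gathered into an output
+ # of shape (num_experts, expert_capacity, hidden_size); tokens beyond an
+ # expert's capacity are dropped and unused capacity slots remain zero.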
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/benchmark_util.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
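+
+# Rough usage sketch (hypothetical shapes): for a of shape (total_tokens, k),
+# b of shape (num_experts, k, n) and batch_sizes summing to total_tokens,
+# gmm(a, b, batch_sizes) returns c of shape (total_tokens, n), running one
+# GEMM per expert over that expert's contiguous block of rows in 'a'.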
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm
+ pass
+ _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b835ac5f6929edb8b547f373212388f34be3868
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py
@@ -0,0 +1,1225 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
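+ # For example, with moe_top_k=2 and shared_expert_weighted_sum=True the
+ # shared expert output is scaled by 1/3 and the routed expert output by
+ # 2/3; with the default (False) the two outputs are simply added.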
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (count_num_experts,) = tokens_per_expert.size()
+    assert count_num_experts == num_experts
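+    # Auxiliary loss: (num_experts / (tokens * top_k)) *
+    # dot(tokens_per_expert, per-expert mean router score)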
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
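+    # e.g. tokens=4096, top_k=4, num_experts=128, world_size=1,
+    # capacity_factor=1.0 -> 4 * 4096 / 128 = 128 tokens per expert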
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
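+    # A capacity of zero (possible with few tokens and a small capacity
+    # factor) would drop every token, so fall back to the busiest expert.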
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Kick off the exchange asynchronously; the handle is awaited below
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
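+    # (dense [num_experts, tokens] tensor, zero for unselected experts)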
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
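+        # Bits needed to radix-sort expert ids, e.g. 128 experts -> 7 bits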
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
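+    # Sketch (assumed usage): create weights with create_shared_expert_weights()
+    # and register them via set_shared_expert_weights() before calling forward().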
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per element (fp16).
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd:DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradX:DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradW:DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(indices, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+ return ops.binned_gather(x, indices, bins, ec)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(indices, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.binned_gather(x, indices, bins, ec)
+
+ def benchmark():
+ return ops.binned_scatter(x, indices, bins)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+ return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+
+ def benchmark():
+ return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/repeat.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
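
A hedged sketch of how `replicate` is typically called: one value per expert bin (for example, a routing weight) is broadcast to every token routed to that bin, and the backward pass reduces gradients back to one value per bin. The shapes and values below are purely illustrative assumptions and require the compiled extension plus a CUDA device:

```python
import torch

# Hypothetical shapes: 4 experts, 8 routed tokens (illustrative only).
expert_values = torch.randn(1, 4, device="cuda")  # one value per expert bin
# Inclusive cumulative token counts per expert; the last entry equals the
# total number of outputs.
bins = torch.tensor([2, 5, 5, 8], dtype=torch.int32, device="cuda")

# Broadcast each expert's value across the tokens in its bin -> (1, 8).
per_token = replicate(expert_values, bins, 8)
```
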
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/round_up.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
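
The truncating-division trick above rounds each entry up to the next multiple of `value`, which is how per-expert token counts are padded to the block size (e.g. 128) elsewhere in this package. A small worked sketch, assuming the helper above is importable:

```python
import torch

# Round int32 token counts up to the next multiple of the block size.
tokens_per_expert = torch.tensor([1, 128, 129, 300], dtype=torch.int32)
print(round_up(tokens_per_expert, 128))
# tensor([128, 128, 256, 384], dtype=torch.int32)
```
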
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
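
`scatter` is the un-permute step of the MoE pipeline: it takes expert-ordered outputs, applies the per-token router weights, and sums the `top_k` copies of each token back into token order (its data gradient reuses the matching `gather` kernel). The sketch below only illustrates the assumed call pattern, mirroring the benchmark code earlier in this diff; it requires the compiled extension and a CUDA device:

```python
import torch

# Hypothetical sizes (illustrative only).
sl, hs, ne, top_k = 16, 32, 4, 1

# Routing metadata, built the same way as in the benchmarks above.
top_expert = torch.randint(0, ne, (sl,), device="cuda", dtype=torch.int32)
bin_ids, indices = ops.sort(top_expert)
bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)

# Stand-in for expert outputs already laid out in expert order
# (e.g. as produced by the gather/padded_gather ops in this package).
permuted = torch.randn(sl * top_k, hs, device="cuda", dtype=torch.float16)
weights = torch.rand(sl * top_k, device="cuda", dtype=torch.float16)

# Un-permute back to token order, weighting and summing the top_k copies.
out = scatter(permuted, indices, bin_ids, weights, bins, top_k)  # (sl, hs)
```
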
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
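
`sort` is a radix sort over integer keys that also returns the permutation which produced the sorted order; in the MoE pipeline it groups token positions by their assigned expert (as in the benchmarks above). A minimal sketch, assuming the compiled extension and a CUDA device:

```python
import torch

# Expert assignment for five tokens (illustrative).
top_expert = torch.tensor([2, 0, 1, 0, 2], device="cuda", dtype=torch.int32)

# Sorted keys and the argsort-style permutation that produced them.
bin_ids, indices = sort(top_expert)
# bin_ids  -> [0, 0, 1, 2, 2]
# indices  -> original positions of those tokens, grouped by expert
```
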
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
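
These decorators play the role of `torch.cuda.amp.custom_fwd`/`custom_bwd` for the ops in this package: under autocast, the forward runs with autocast disabled but its eligible CUDA floating-point arguments are pre-cast to the active autocast dtype, and the backward likewise runs with autocast disabled. A minimal sketch of the pattern (`ScatterOp` above is a real example; the toy function below is hypothetical):

```python
import torch

class ScaleOp(torch.autograd.Function):
    """Toy autograd function wrapped with the decorators defined above."""

    @staticmethod
    @custom_fwd
    def forward(ctx, x, scale):
        # Under autocast, `x` arrives already cast to the autocast dtype.
        ctx.scale = scale
        return x * scale

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        return grad * ctx.scale, None
```
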
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sum.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
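
The special case above squeezes a singleton dimension instead of launching a reduction kernel; the result has the same shape and values as an ordinary sum over that dimension. A tiny sketch (note that this `sum` shadows the builtin inside the module):

```python
import torch

x = torch.randn(1, 8)
y = torch.randn(4, 8)

# Singleton dim: squeezed away rather than reduced (identical result).
assert torch.equal(sum(x, dim=0), x.squeeze(0))
# General case: an ordinary reduction.
assert sum(y, dim=0).shape == (8,)
```
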
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store the result block to the sparse output matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/matrix.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
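
`Matrix` stores a block-compressed-sparse-row (BCSR) layout: `data` holds the nonzero blocks as `(nnz_blocks, block, block)`, `row_indices`/`column_indices` give each block's block-row and block-column, `offsets` marks where each block row begins, and the `*_t` fields are the lazily built transpose metadata from `_transpose`. A hedged sketch of constructing one via the conversion ops (mirroring the tests later in this diff; assumes a CUDA device):

```python
import torch
import stk

# Hypothetical example: 256x256 half-precision matrix with 128x128 blocks.
mask = stk.random.dense_mask(256, 256, 0.5, blocking=128)
dense = (torch.randn(256, 256) * mask).half()

m = stk.ops.to_sparse(dense, blocking=128).cuda()
print(m.shape, m.blocking)      # torch.Size([256, 256]) 128
print(m.data.shape)             # (nnz_blocks, 128, 128)
print(m.offsets)                # one entry per block row, plus one
print(m.row_indices, m.column_indices)
```
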
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
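
The three wrappers encode the operand layout in their names: `dsd` produces a dense output from sparse × dense, `dds` a dense output from dense × sparse, and `sdd` a block-sparse output (with the topology of `topo`) from dense × dense. A hedged usage sketch mirroring the tests below (assumes a CUDA device; every dimension must be divisible by the blocking, see `_validate_matmul_dims` in the Triton backend):

```python
import torch
import stk

m, k, n, blocking = 512, 512, 512, 128
cuda = torch.device("cuda")

# Block-sparse LHS, dense RHS -> dense output.
a_mask = stk.random.dense_mask(m, k, 0.5, blocking)
a = stk.ops.to_sparse((torch.randn(m, k) * a_mask).half(), blocking).cuda()
b = torch.randn(k, n, dtype=torch.float16, device=cuda)
out_dense = stk.ops.dsd(a, b)                    # (m, n) dense tensor

# Dense x dense, computed only where `topo` has nonzero blocks -> stk.Matrix.
topo_mask = stk.random.dense_mask(m, n, 0.5, blocking)
topo = stk.ops.to_sparse((torch.randn(m, n) * topo_mask).half(), blocking).cuda()
lhs = torch.randn(m, k, dtype=torch.float16, device=cuda)
out_sparse = stk.ops.sdd(lhs, b, topo)
```
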
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
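+
+
+if __name__ == '__main__':
+    # Illustrative sketch (not part of the original file): dense_mask builds a
+    # {0, 1} mask with the requested fraction of blocking x blocking blocks
+    # zeroed out. With two 8x8 blocks and sparsity 0.5, exactly one block
+    # (64 entries) survives.
+    m = dense_mask(8, 16, sparsity=0.5, blocking=8)
+    assert m.shape == (8, 16)
+    assert int(torch.count_nonzero(m)) == 64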
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
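+
+
+if __name__ == '__main__':
+    # Illustrative smoke test (not part of the original file) for the
+    # convenience wrappers above. The underlying kernels are CUDA
+    # implementations, so this only runs when a GPU is available.
+    if torch.cuda.is_available():
+        x = torch.randint(0, 8, (16,), dtype=torch.int32, device='cuda')
+        # Inclusive cumulative sum over the single dimension.
+        csum = cumsum(x, dim=0, exclusive=False)
+        # Radix sort over the low 3 bits (values are < 8), returning both the
+        # sorted values and the permutation that produced them.
+        sorted_vals, sorted_idx = argsort(x, end_bit=3)
+        print(csum[-1].item(), sorted_vals.tolist())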
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/activation_fn.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
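+
+
+# Usage sketch (illustrative): apply an elementwise activation to the nonzero
+# blocks of a sparse Matrix while reusing its topology. `topo_matrix` below
+# stands for any stk Matrix, e.g. the output of stk.ops.sdd:
+#
+#   import torch.nn.functional as F
+#   y = act_fn(topo_matrix, F.gelu, approximate='tanh')
+#   y, grad_fn = act_fn(topo_matrix, F.gelu, return_grad_fn=True)
+#
+# Only x.data is transformed; row/column indices and offsets are carried over.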
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/all_to_all.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
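+
+
+# Usage sketch (illustrative; requires an initialized process group and one
+# CUDA device per rank). Each rank sends input_split_sizes[i] rows of x to
+# rank i and receives output_split_sizes[j] rows from rank j:
+#
+#   out, handle = all_to_all(
+#       x,                        # [sum(input_split_sizes), hidden]
+#       output_split_sizes,       # list[int], one entry per rank
+#       input_split_sizes,        # list[int], one entry per rank
+#       group=dist.group.WORLD,
+#       async_op=True,
+#   )
+#   ...                           # overlap other work here
+#   handle.wait()                 # out is valid after the wait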
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/arguments.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+            import triton
+            from packaging import version as _version
+            # Compare parsed versions; plain string comparison misorders
+            # releases such as '3.10.0' vs '3.2.0'.
+            if _version.parse(triton.__version__) >= _version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
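+
+
+# Construction sketch (illustrative values only). The default device factory
+# queries torch.cuda.current_device(), so a visible GPU is assumed, and
+# mlp_impl='grouped' requires the grouped_gemm backend to be installed:
+#
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_impl='grouped',
+#       fp16=False,
+#       bf16=True,
+#   )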
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/common.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
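+
+
+# Usage sketch (illustrative): the registry maps the (mlp_type, mlp_impl)
+# pair carried by an Arguments instance to a concrete expert MLP class.
+#
+#   args = Arguments(mlp_type='glu', mlp_impl='grouped', ...)
+#   expert_mlp = get(args)   # -> glu.GroupedGLU instance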
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/dmoe.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
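+        #
+        # Worked example (illustrative): with blocks_per_row = 4, the block in
+        # sparse row 2, column slot 1 has offset 2 * 4 + 1 = 9. After the sort
+        # by column index, integer-dividing that offset by blocks_per_row
+        # recovers row 2, which becomes a column index of the transposed
+        # matrix.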
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/gelu.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
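+
+
+if __name__ == '__main__':
+    # Illustrative sanity check (not part of the original file): on plain
+    # tensors, the fused in-place backward should agree with autograd through
+    # the tanh-approximated GeLU up to the truncated constants used above.
+    x = torch.randn(64, dtype=torch.float64, requires_grad=True)
+    F.gelu(x, approximate='tanh').backward(torch.ones_like(x))
+    manual = gelu_backward_(torch.ones_like(x), x.detach().clone())
+    print(torch.allclose(manual, x.grad, atol=1e-6))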
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/glu.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
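+
+
+# Computation sketch (illustrative, dense notation): per expert, the GLU
+# variants above all evaluate
+#
+#   h = activation_fn(x @ w1.T) * (x @ v1.T)   # gated hidden state
+#   y = h @ w2                                 # down projection
+#
+# The sparse (stk) and grouped (gmm) code paths only change how the
+# block-diagonal expert batching is expressed, not this arithmetic.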
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/memory_test.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/mlp.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # Use a weighted sum for the shared expert output,
+ # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/moe.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} tokens_per_expert entries '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
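+# Illustrative example (hypothetical numbers): with one layer, 8 tokens,
+# moe_num_experts=4 and moe_top_k=1, if tokens_per_expert = [4, 2, 1, 1] and the
+# mean expert_scores are [0.5, 0.25, 0.125, 0.125], then
+#   scale = (4 * moe_loss_weight) / (1 * 8 * 1) = 0.5 * moe_loss_weight
+#   loss  = scale * dot([4, 2, 1, 1], [0.5, 0.25, 0.125, 0.125])
+#         = 1.375 * moe_loss_weight
+# A perfectly balanced routing ([2, 2, 2, 2] with mean scores [0.25, 0.25,
+# 0.25, 0.25]) gives loss = moe_loss_weight.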
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
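+ # Example (hypothetical values): with moe_top_k=2, tokens=1024, a single
+ # expert-parallel rank and num_experts=128, tokens_per_expert is
+ # 2 * 1024 / 128 = 16, so moe_capacity_factor=1.0 gives a capacity of 16 slots.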
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
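+ # Tiny illustration (hypothetical values) with num_experts=3 and
+ # top_expert = [1, 0, 1, 2]:
+ #   bin_ids           = [0, 1, 1, 2]  (sorted expert ids)
+ #   indices           = [1, 0, 2, 3]  (argsort of top_expert)
+ #   tokens_per_expert = [1, 2, 1]
+ #   bins              = [1, 3, 4]     (inclusive cumsum, i.e. bin upper bounds)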
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+ # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
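+ # Illustrative example (hypothetical values): with 2 expert-parallel ranks,
+ # 2 experts per rank and hidden_sharding_degree=1, suppose this rank computes
+ # tokens_per_expert = [3, 1, 2, 0]. Then send_counts = [3 + 1, 2 + 0] = [4, 2]
+ # (tokens destined for rank 0 and rank 1), and recv_counts are the row sums
+ # of parallel_tokens_per_expert, i.e. how many tokens each peer contributes
+ # for the two experts owned locally.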
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/mpu.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
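+
+# Worked example (hypothetical values): with an expert-parallel world size of 8,
+# moe_num_experts=4 and ffn_hidden_size=4096:
+#   expert_sharding_degree = min(8, 4) = 4
+#   hidden_sharding_degree = 8 // 4    = 2
+#   experts_per_rank       = 4 // 4    = 1
+#   features_per_rank      = 4096 // 2 = 2048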
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/router.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
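+
+# Shape summary for LearnedRouter.forward (descriptive of the code above):
+#   x:              [sl, bs, hidden_size] (flattened to [sl * bs, hidden_size])
+#   scores/logits:  [sl * bs, moe_num_experts]
+#   expert_weights: [sl * bs, moe_top_k]
+#   expert_indices: [sl * bs, moe_top_k]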
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch28-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_megablocks_4f35d2a.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_megablocks_4f35d2a.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..53500b2527dbc4099ec960d6afcc290797c61fec
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_megablocks_4f35d2a.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c25df724a678c8783c3252180f7bcc124e52baf2f6f4f8bd6d69ec2ff2c58e7d
+size 15046832
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..69479df3044627d4a8ac3fb70d0b1f0e9b22deed
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_4f35d2a
+ops = torch.ops._megablocks_4f35d2a
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_4f35d2a::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_version.py b/build/torch28-cxx11-cu126-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/backend/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/backend/kernels.py b/build/torch28-cxx11-cu126-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
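+# Illustration of the bins/padded_bins relationship (hypothetical values): with
+# bins = [2, 3] and padded_bins = [4, 8], the two tokens of expert 0 land in
+# output rows 0-1 (rows 2-3 stay zero padding) and the single token of expert 1
+# lands in row 4 (rows 5-7 stay zero padding); output_rows = padded_bins[-1] = 8.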
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
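+# NOTE: binned_gather produces a dense [num_experts, expert_capacity, columns]
+# buffer; tokens routed to an expert beyond its expert_capacity are simply never
+# copied (the kernel grid only has expert_capacity entries per expert), which is
+# how capacity-based token dropping happens in the non-dropless MoE path.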
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/benchmark_util.py b/build/torch28-cxx11-cu126-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print(f'mean time = {time:.3f}ms, std time = {std:.3f}ms')
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
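+
+# Example usage (illustrative; `fn` can be any zero-argument CUDA callable):
+#   mean_ms, std_ms = benchmark_function(lambda: moe(x))
+#   log_benchmark('MoE forward', {'batch_size': 8}, mean_ms, std_ms)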
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py b/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package, since
+# grouped_gemm is vendored into megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
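+
+# Shape conventions implied by _allocate_output (descriptive of the code above):
+#   default:      a [tokens, k], b [num_groups, k, n] -> c [tokens, n]
+#   trans_b=True: a [tokens, k], b [num_groups, n, k] -> c [tokens, n]
+#   trans_a=True: a [tokens, k], b [tokens, n]        -> c [num_groups, k, n]
+# where the rows of the 2-D operands are grouped contiguously according to
+# batch_sizes.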
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py b/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm_util.py b/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # grouped_gemm is vendored into this package (the `grouped_gemm`
+ # subpackage below), so the upstream `import grouped_gemm` is unnecessary.
+ _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/layers.py b/build/torch28-cxx11-cu126-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a66957b08d748fd5b4fca8ad5f2c68c81cf429
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/layers.py
@@ -0,0 +1,1230 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Compute the expert sharding degree from the world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
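+
+# Example (hypothetical value): with moe_jitter_eps=0.01 every activation is
+# scaled by an independent factor drawn uniformly from [0.99, 1.01].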
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
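+
+# Shape sketch (hypothetical sizes): for x of shape [sl=2, bs=3, hs=1152] with
+# moe_num_experts=128 and moe_top_k=4, logits is [6, 128] while expert_weights
+# and expert_indices are both [6, 4]; the weights are the softmax over the
+# selected top-k logits.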
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
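+
+# Shape sketch for mlp_forward (hypothetical sizes): with x of shape
+# [num_experts, capacity, hs], w1 of shape [num_experts, hs, 2*isz] and w2 of
+# shape [num_experts, isz, hs], gate_up is [num_experts, capacity, 2*isz]; the
+# even columns act as the gate, the odd columns as the up projection, and the
+# clamped sigmoid gating produces an output of [num_experts, capacity, hs].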
+
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
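+
+# Example (hypothetical top_k): with moe_top_k=4 and weighted sum enabled, the
+# shared expert contributes 1/5 of the output and the routed experts 4/5;
+# otherwise the two outputs are simply added.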
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+    if len(expert_scores) != num_layers_per_pipeline_stage:
+        raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+            f"{args.num_layers}\npipeline_model_parallel_size = "
+            f"{args.pipeline_model_parallel_size}\n"
+            "num_layers_per_virtual_pipeline_stage"
+            f" = {args.num_layers_per_virtual_pipeline_stage}",
+        )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
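+
+# Worked example (hypothetical numbers): with tokens=4, top_k=1 and
+# num_experts=2, perfectly balanced routing gives tokens_per_expert=[2, 2] and
+# mean expert scores of [0.5, 0.5], so the loss is
+# 2 / (4 * 1) * (2*0.5 + 2*0.5) = 1.0, the value reached under perfectly
+# uniform routing.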
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
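+
+# Example (hypothetical assignment): for top_expert=[1, 0, 1, 1] and
+# num_experts=2, the sort yields bin_ids=[0, 1, 1, 1] with indices pointing
+# back to the original token order, tokens_per_expert=[1, 3] and bins=[1, 4],
+# i.e. expert 0 owns slots [0, 1) and expert 1 owns slots [1, 4).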
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
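+
+# Worked example (hypothetical numbers): with tokens=1024, top_k=4,
+# num_experts=128, a single rank and moe_capacity_factor=1.0, each expert is
+# given capacity for int(1.0 * 4 * 1024 * 1 / 128) = 32 token slots.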
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
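+
+# Data-flow sketch for forward_once (single device): tokens are sorted by their
+# assigned expert, gathered into per-expert bins of size expert_capacity, pushed
+# through mlp_forward with one batched matmul per projection, and scattered back
+# to the original token order weighted by the router probabilities.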
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+    # Save the load-balancing loss if needed. The loss weight is currently
+    # fixed at 0.0, so this branch is disabled until it is made configurable.
+    moe_loss_weight = 0.0
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
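+
+# Usage sketch for the shared-expert variant (hypothetical shapes and values,
+# shown for illustration only, not a tested recipe):
+#
+#   mlp = MegaBlocksMoeMLPWithSharedExpert()
+#   # ...attach mlp.router and mlp.experts as for MegaBlocksMoeMLP, then:
+#   shared_up, shared_down, _, _ = create_shared_expert_weights(
+#       hidden_size=1152, shared_expert_hidden_size=2304,
+#       device=torch.device("cuda"), dtype=torch.float32,
+#       init_method=torch.nn.init.kaiming_uniform_,
+#   )
+#   mlp.set_shared_expert_weights(shared_up, shared_down, weighted_sum=True)
+#   out, expert_weights = mlp(x)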
+
+
+# Patch for XPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/binned_gather.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/binned_scatter.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/cumsum.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/gather.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/histogram.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd:DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradX:DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradW:DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/padded_gather.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/padded_scatter.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(indices, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+ return ops.binned_gather(x, indices, bins, ec)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(indices, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.binned_gather(x, indices, bins, ec)
+
+ def benchmark():
+ return ops.binned_scatter(x, indices, bins)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+ return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+
+ def benchmark():
+ return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/repeat.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+    if all(t == 1 for t in tiling):
+ return x
+ return x.repeat(*tiling)
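+
+# Usage sketch (illustrative, not part of the original file):
+#   repeat(x, torch.Size((1, 1)))  # returns x unchanged (no copy)
+#   repeat(x, torch.Size((2, 1)))  # equivalent to x.repeat(2, 1)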
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/replicate.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/round_up.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
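+
+# Example (illustrative, not part of the original file):
+#   round_up(torch.tensor([3, 128, 130], dtype=torch.int32), 128)
+#   -> tensor([128, 128, 256], dtype=torch.int32)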
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/scatter.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/sort.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
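+
+# Usage sketch (illustrative, not part of the original file): when the values are
+# known to be bounded, e.g. expert ids in [0, ne), passing a smaller end_bit bounds
+# the number of bits the kernel has to sort on:
+#   bin_ids, indices = sort(top_expert, max(int(ne - 1).bit_length(), 1))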
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/stk_autocast.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+ elif isinstance(x, map):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/sum.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
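+
+# Example (illustrative, not part of the original file): for x of shape (1, 4),
+# sum(x, dim=0) squeezes the singleton dimension instead of reducing, returning a
+# view with the same values; for shape (3, 4) it falls back to x.sum(dim=0).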
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/ops/topology.py b/build/torch28-cxx11-cu126-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/autocast.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+ elif isinstance(x, map):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
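+
+# Note (editorial, not in the original source): custom_fwd casts eligible CUDA
+# floating-point arguments to the active autocast dtype and then runs the wrapped
+# forward with autocast disabled, so the autograd Function sees one consistent
+# dtype; custom_bwd likewise runs the backward with autocast disabled.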
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
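+
+# For example (editorial note, not in the original source): a contiguous (m, n)
+# tensor has strides (n, 1); after .t() it has shape (n, m) with strides (1, n),
+# which is exactly the pattern _is_transposed checks for.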
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
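+
+# Example (illustrative, not part of the original file): for a block-sparse matrix
+# with offsets = [0, 2, 3] (two nonzero blocks in row 0, one in row 1), the kernel
+# writes out = [0, 0, 1], i.e. one block-row id per nonzero block.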
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/matrix.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D data.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
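+
+# Worked example (editorial, not in the original source): for row_indices = [0, 1, 1]
+# and column_indices = [1, 0, 1] with two block columns, argsorting the column
+# indices can give gather_indices = [1, 0, 2] (ties may land in either order), so
+# column_indices_t = [1, 0, 1], block_offsets_t = [1, 0, 2] and offsets_t = [0, 1, 3]
+# (one block in column 0, two blocks in column 1).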
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
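+
+# Usage sketch (illustrative, not part of the original file): given an stk.Matrix
+# `a`, an elementwise scale sharing its topology can be built and applied as
+#   b = Matrix(a.size(), torch.full_like(a.data, 2.0), a.row_indices,
+#              a.column_indices, a.offsets)
+#   c = mul(a, b)  # same sparsity pattern, entries a.data * 2.0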
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
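+
+# Illustrative sketch (not part of the original source): with blocking=2, a single
+# block coordinate expands into the four element coordinates it covers. Assuming a
+# small CPU tensor for brevity:
+#
+#     >>> idxs = torch.tensor([[1, 3]])            # block at row 1, col 3
+#     >>> _expand_for_blocking(idxs, blocking=2)
+#     tensor([[2, 6], [2, 7], [3, 6], [3, 7]])     # element rows 2-3, cols 6-7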
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
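+
+# Illustrative round trip (not part of the original source), mirroring the unit
+# tests; dense_mask comes from stk.random in normal usage:
+#
+#     >>> mask = dense_mask(16, 16, 0.5, blocking=8)
+#     >>> x = (torch.randn(16, 16) * mask).half()
+#     >>> sp = to_sparse(x, blocking=8)
+#     >>> bool(torch.equal(to_dense(sp), x))
+#     True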
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/random/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/random/random_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
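+
+# Illustrative sketch (not part of the original source): with sparsity=0.75 and
+# blocking=4, an 8x8 mask keeps round(4 * 0.25) = 1 of its 4 blocks, i.e. 16 ones.
+#
+#     >>> m = dense_mask(8, 8, sparsity=0.75, blocking=4)
+#     >>> int(m.sum())
+#     16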
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py b/build/torch28-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from .. import random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/xpu_fused_moe.py b/build/torch28-cxx11-cu126-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2e7c6692f101f9141e9d716c8af6ac92be95351
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,577 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops
+
+
+# Install meta kernels for torch.compile compatibility
+def _install_xpu_meta_kernels():
+ """Install meta kernels for XPU MoE operations to support torch.compile"""
+
+ # Patch cutlass_grouped_gemm_interface
+ if hasattr(ops, "cutlass_grouped_gemm_interface"):
+ original_gemm = ops.cutlass_grouped_gemm_interface
+
+ def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
+ expert_first_token_offset, N, K, num_experts,
+ is_B_int4, is_B_mxfp4):
+ if torch.compiler.is_compiling():
+ # Meta implementation - ptr_D is the output, return it
+ return ptr_D
+ return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
+ expert_first_token_offset, N, K, num_experts,
+ is_B_int4, is_B_mxfp4)
+
+ ops.cutlass_grouped_gemm_interface = gemm_with_meta
+
+ # Patch fused_moe_prologue
+ if hasattr(ops, "fused_moe_prologue"):
+ original_prologue = ops.fused_moe_prologue
+
+ def prologue_with_meta(input, token_selected_experts, token_final_scales,
+ workspace, hidden_size, inter_size, num_experts_on_rank):
+ if torch.compiler.is_compiling():
+ # Meta implementation - this op modifies workspace in-place
+ return None
+ return original_prologue(input, token_selected_experts, token_final_scales,
+ workspace, hidden_size, inter_size, num_experts_on_rank)
+
+ ops.fused_moe_prologue = prologue_with_meta
+
+ # Patch moe_gather
+ if hasattr(ops, "moe_gather"):
+ original_gather = ops.moe_gather
+
+ def gather_with_meta(output, moe_output, topk_weights,
+ unpermuted_row_to_permuted_row, num_experts):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output is modified in-place
+ return None
+ return original_gather(output, moe_output, topk_weights,
+ unpermuted_row_to_permuted_row, num_experts)
+
+ ops.moe_gather = gather_with_meta
+
+ # Patch activation ops
+ for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
+ "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
+ "swigluoai_and_mul"]:
+ if hasattr(ops, act_name):
+ original_act = getattr(ops, act_name)
+
+ def make_act_wrapper(orig_fn):
+ def act_with_meta(*args, **kwargs):
+ if torch.compiler.is_compiling():
+ # Meta implementation - in-place ops, return None
+ return None
+ return orig_fn(*args, **kwargs)
+ return act_with_meta
+
+ setattr(ops, act_name, make_act_wrapper(original_act))
+
+
+# Install meta kernels on module load
+_install_xpu_meta_kernels()
+
+
+# Default grouped GEMM wrapper; an Xe2-specific variant follows below.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
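+
+# Worked examples (illustrative, not part of the original source):
+#
+#     >>> compute_num_tokens_per_block(64, 8)       # ceilDiv(64, 32) * 8 = 16 <= 32
+#     32
+#     >>> compute_num_tokens_per_block(16384, 128)  # no candidate satisfies the bound
+#     1024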
+
+
+def implement_zp(qweight):
+    # Convert u4 to s4 so the GEMM kernel does not need to apply a zero point.
+    # Only the default zero point (8) is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
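+
+# Illustrative sketch (not part of the original source): each uint8 byte holds two
+# u4 nibbles; subtracting the default zero point (8) and re-encoding as
+# sign-magnitude s4 turns 0xAB (nibbles 10 and 11) into +2 and +3, packed as 0x23.
+#
+#     >>> implement_zp(torch.tensor([[0xAB]], dtype=torch.uint8))
+#     tensor([[35]], dtype=torch.uint8)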
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+    is_fp8: bool
+    is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ function; temporarily exposed here until GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
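+    # Note (illustrative, not part of the original source): each size below is
+    # padded up to a multiple of 256 bytes, so e.g. a 1000-byte request reserves
+    # 1024 bytes and the next region starts at offset 1024.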
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size].view(torch.int64)
+ unpermuted_row_to_permuted_row = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size].view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ unpermuted_row_to_permuted_row,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
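+
+# Illustrative sketch (not part of the original source): with moe_jitter_eps=0.1
+# every element is scaled by a uniform factor in [0.9, 1.1).
+#
+#     >>> y = apply_jitter(torch.ones(2, 3), 0.1)
+#     >>> bool(((y >= 0.9) & (y < 1.1)).all())
+#     True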
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
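+
+# Illustrative sketch (not part of the original source; shapes are hypothetical):
+# route 4 tokens of width 8 across 16 experts with top-2 routing.
+#
+#     >>> router = torch.nn.Linear(8, 16)
+#     >>> logits, weights, idx = route_tokens_xpu(
+#     ...     torch.randn(4, 8), router.weight, router.bias,
+#     ...     moe_top_k=2, moe_num_experts=16)
+#     >>> logits.shape, weights.shape, idx.shape
+#     (torch.Size([4, 16]), torch.Size([4, 2]), torch.Size([4, 2]))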
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=moe_num_experts,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
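+
+# Illustrative sketch (not part of the original source), assuming the underlying
+# kernels accept this dtype/device and the default dim:
+#
+#     >>> x = torch.tensor([1, 2, 3], device="cuda")
+#     >>> cumsum(x)                   # inclusive -> tensor([1, 3, 6], ...)
+#     >>> cumsum(x, exclusive=True)   # exclusive -> tensor([0, 1, 3], ...)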
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
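+
+# Illustrative sketch (not part of the original source): the radix sort operates on
+# integer keys and iota_out receives the permutation that sorts x (dtype/device
+# requirements follow the underlying kernel).
+#
+#     >>> x = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")
+#     >>> vals, idx = argsort(x, end_bit=2)
+#     >>> vals, idx   # tensor([1, 2, 3], ...), tensor([1, 2, 0], ...)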
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/activation_fn.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/all_to_all.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/arguments.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (allows a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+        int] = None  # hidden size of the shared expert if it should differ from ffn_hidden_size (the default)
+    shared_expert_weighted_sum: bool = False  # enable using a weighted sum for the shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
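+
+# Illustrative usage (not part of the original source), assuming a CUDA device for
+# the default device factory and an available grouped GEMM backend:
+#
+#     >>> args = Arguments(
+#     ...     hidden_size=1024,
+#     ...     ffn_hidden_size=4096,
+#     ...     moe_num_experts=8,
+#     ...     moe_top_k=2,
+#     ...     mlp_impl='grouped',   # 'sparse' requires triton < 3.2.0
+#     ... )
+#     >>> args.shared_expert_hidden_size   # defaults to ffn_hidden_size
+#     4096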
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/common.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e. only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
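+
+# Illustrative usage (not part of the original source), assuming a CUDA device and
+# an available grouped GEMM backend so the expert weights can be allocated:
+#
+#     >>> args = Arguments(mlp_type='glu', mlp_impl='grouped')
+#     >>> isinstance(get(args), glu.GroupedGLU)
+#     True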
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/dmoe.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+        # There is a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
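+
+        # Illustrative trace (not part of the original source): with blocking=128
+        # and a 256 x 256 topology holding nonzero blocks at (0, 1) and (1, 0),
+        # column_indices = [1, 0] sorts to gather_indices = [1, 0], giving
+        # column_indices_t = [1, 0], block_offsets_t = [1, 0] and
+        # offsets_t = [0, 1, 2] (one block per column of the transpose).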
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+ # the matrix muliplications. Caculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
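+
+        # Illustrative trace (not part of the original source): with blocking=128,
+        # 3 experts and top-1 assignments [2, 0, 2, 1], tokens_per_expert = [1, 1, 2],
+        # bins = [1, 2, 4], padded_tokens_per_expert = [128, 128, 128] and
+        # padded_bins = [128, 256, 384].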
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/gelu.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/glu.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+                'Memory optimized implementation not yet supported for GLU with sparse kernels.',
+ )
+
+        w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)
+        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Apply the activation and gate (GLU).
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
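+        # Grouped GLU over the local experts: for expert e, given its slice x_e of
+        # batch_sizes[e] tokens, this computes (act(x_e @ w1_e.T) * (x_e @ v1_e.T)) @ w2_e.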
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/memory_test.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
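+# Each test case is (batch_size, sequence_length, hidden_size, ffn_hidden_size,
+# num_experts, top_k), matching the positional arguments of test_memory below.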
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6))
+
+ # Calculate weight and gradient memory usage.
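+    # NOTE: the factor of 2 assumes 2 bytes per element, since this test
+    # runs the layer in bfloat16.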
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6))
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/mlp.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+    # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
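+        # x arrives routed and padded per expert, shaped roughly
+        # (experts_per_rank, expert_capacity, hidden_size), so both projections
+        # run as batched matmuls (bmm) over the local experts.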
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
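+        # Illustrative example: with weighting enabled and moe_top_k=4, the
+        # shared expert contributes 1/5 of the output and the routed experts 4/5.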
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/moe.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
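+        # Illustrative example: with top_k=2, tokens=4096, world_size=8,
+        # num_experts=64 and moe_capacity_factor=1.0, each expert keeps at
+        # most int(1.0 * 2 * 4096 * 8 / 64) = 1024 tokens.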
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
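+        # Switch-Transformer-style auxiliary loss: with perfectly uniform
+        # routing the dot product equals tokens * top_k / num_experts, so the
+        # scale normalizes the loss to ~1 in the balanced case.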
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
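+        # Small worked example (illustrative): with num_experts=3 and
+        # top_expert=[2, 0, 1, 0], sorting gives bin_ids=[0, 0, 1, 2] and
+        # indices=[1, 3, 2, 0]; the histogram is tokens_per_expert=[2, 1, 1]
+        # and the inclusive cumsum gives bins=[2, 3, 4].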
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+        tokens_per_expert: torch.Tensor, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignments. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+            # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/mpu.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
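+#
+# Illustrative example: with world_size=8 and moe_num_experts=4, the experts
+# are sharded 4 ways (expert_sharding_degree=4) and each expert's FFN weights
+# are split across 2 devices (hidden_sharding_degree=2), so each rank owns
+# 1 expert and ffn_hidden_size // 2 of its features.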
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/router.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
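+    # Router z-loss: for each router, penalize large logits via
+    # mean_over_tokens(logsumexp(logits)^2), scaled by moe_zloss_weight.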
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch28-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+    """Returns a SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_megablocks_4f35d2a.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_megablocks_4f35d2a.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..13621e1ff5f2f2b5852f2977186205013d3b2b62
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_megablocks_4f35d2a.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30359e1959d207a4ba71879037aafd22b8dada4c664c191dc4310f8b108131f8
+size 20995704
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..69479df3044627d4a8ac3fb70d0b1f0e9b22deed
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_4f35d2a
+ops = torch.ops._megablocks_4f35d2a
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_4f35d2a::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_version.py b/build/torch28-cxx11-cu128-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/backend/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/backend/kernels.py b/build/torch28-cxx11-cu128-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
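+#
+# Each program copies row `index_a // TOP_K` of 'a' to row `index_b` of 'b'
+# when A_TO_B is set, or row `index_b` of 'b' back to row `index_a` of 'a'
+# otherwise, optionally scaling by `weights[index_a]`.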
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since its rows may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
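+
+# Note: gather and scatter form a round trip. An illustrative (not executed)
+# sequence, with expert_mlp standing in for the caller's expert computation:
+#   h = gather(x, indices, bin_ids, None, bins, top_k)      # group tokens by expert
+#   h = expert_mlp(h)
+#   y = scatter(h, indices, bin_ids, weights, bins, top_k)  # weighted un-permute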
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
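+#
+# Each program computes one entry of 'wgrad': the dot product between a
+# (padded) expert-output row of 'x' and the incoming gradient row of 'grad'
+# for the corresponding token, i.e. the gradient of that token's routing weight.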
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per token/expert pair. Array 'x' has a greater or
+    # equal number of rows since its bins may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
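+#
+# Unlike the padded copy above, the binned copy uses a fixed-capacity layout of
+# shape (num_experts, expert_capacity, num_columns); tokens beyond an expert's
+# capacity are dropped and unused slots are left as zeros.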
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/benchmark_util.py b/build/torch28-cxx11-cu128-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+    print(f'mean time = {time:.3f}ms, std time = {std:.3f}ms')
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+    for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
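+
+
+# Example usage (illustrative; names are placeholders):
+#   mean_ms, std_ms = benchmark_function(lambda: torch.matmul(a, b))
+#   log_benchmark('Matmul', {'m': a.shape[0], 'n': b.shape[1]}, mean_ms, std_ms)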
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py b/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
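+# Shape sketch (derived from _allocate_output above; illustrative, not checked
+# against the CUDA op): with trans_a=False and trans_b=False, 'a' is
+# (sum(batch_sizes), k), 'b' is (num_groups, k, n), and the result is
+# (sum(batch_sizes), n); group i multiplies the batch_sizes[i]-row slice of 'a'
+# by b[i].
+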
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py b/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm_util.py b/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored in this repository, so the external package
+    # import is not required; availability is assumed.
+    # import grouped_gemm
+    _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+    )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/layers.py b/build/torch28-cxx11-cu128-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a66957b08d748fd5b4fca8ad5f2c68c81cf429
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/layers.py
@@ -0,0 +1,1230 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+            if torch.compiler.is_compiling():
+                # Meta implementation - return tensors with correct shape/dtype/device
+                return torch.empty_like(x), torch.empty_like(x)
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+                # Meta implementation - match the real kernel's output shape:
+                # (num_experts, expert_capacity, hidden_size)
+                if x.dim() >= 2:
+                    hidden_size = x.size(-1)
+                    return torch.empty(
+                        (bins.shape[0], bin_size, hidden_size),
+                        dtype=x.dtype,
+                        device=x.device,
+                    )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+                # Meta implementation - the real kernel reduces to (tokens, hidden_size)
+                if x.dim() >= 3:
+                    tokens = indices.numel() // top_k if top_k > 0 else x.size(1)
+                    return torch.empty(
+                        (tokens, x.size(2)), dtype=x.dtype, device=x.device
+                    )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
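+
+
+# Illustrative example (assumed values): with world_size=8, moe_num_experts=4
+# and ffn_hidden_size=3072, expert_sharding_degree=4, hidden_sharding_degree=2,
+# experts_per_rank=1 and features_per_rank=1536.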
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
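+    # Note: the gate and up projections are interleaved along the last dim of
+    # 'gate_up'. The activation below is a sigmoid-gated GLU,
+    # glu = gate * sigmoid(alpha * gate), with the gate clamped from above by
+    # 'limit' and the up branch clamped to [-limit, limit].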
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
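+
+
+# Illustrative example (assumed values): tokens=1024, top_k=4, num_experts=128,
+# world_size=1 and moe_capacity_factor=1.0 give tokens_per_expert = 32 and an
+# expert capacity of 32 slots per expert.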
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
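+
+
+# Note: if each row of expert_scores is a probability distribution over the
+# experts, a perfectly uniform routing makes the value above equal to 1.0;
+# concentrating tokens and scores on the same few experts makes it larger.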
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
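+
+
+# Illustrative example (assumed values): with num_experts=3 and
+# top_expert=[2, 0, 1, 0], sorting gives bin_ids=[0, 0, 1, 2] with 'indices'
+# holding the matching permutation of token positions, the histogram gives
+# tokens_per_expert=[2, 1, 1], and the inclusive cumulative sum gives
+# bins=[2, 3, 4].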
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
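+    # Overview: tokens are permuted locally by target expert, exchanged across
+    # ranks with all_to_all so each rank only runs its local experts, then the
+    # results are exchanged back and un-permuted (see the numbered steps below).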
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
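+    # router_scores has shape (num_experts, tokens): the dense top-k routing
+    # weights scattered back over the full expert dimension, then transposed.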
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
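+
+    # Illustrative usage (assumed shapes and names, not taken from the docs):
+    #   mlp = MegaBlocksMoeMLPWithSharedExpert()
+    #   up = torch.empty(shared_hidden, hidden, device="cuda", dtype=torch.bfloat16)
+    #   down = torch.empty(hidden, shared_hidden, device="cuda", dtype=torch.bfloat16)
+    #   mlp.set_shared_expert_weights(up, down, weighted_sum=True)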
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/megablocks/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
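
Together these ops form the usual token-routing pipeline: sort tokens by expert, count them, turn the counts into bin boundaries, then gather/scatter. A rough usage sketch, assuming a CUDA device and that the compiled extension is importable under a `megablocks` package (hypothetical import path):

```python
# Sketch of the routing pipeline built from these ops (assumes a CUDA device
# and that the compiled megablocks package, including _ops, is importable).
import torch
from megablocks import ops  # hypothetical import path

sl, hs, ne, top_k = 1024, 512, 8, 1
x = torch.randn(sl, hs, device="cuda", dtype=torch.float16)
top_expert = torch.randint(0, ne, (sl,), device="cuda", dtype=torch.int32)

bin_ids, indices = ops.sort(top_expert)                # order tokens by expert
tokens_per_expert = ops.histogram(top_expert, ne)      # per-expert counts
padded = ops.round_up(tokens_per_expert, 128)          # pad to the block size
padded_bins = ops.inclusive_cumsum(padded, 0)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)

x_perm = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
y = ops.padded_scatter(x_perm, indices, bin_ids, None, bins, padded_bins, top_k)
```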
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2-byte (fp16) elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/binned_gather.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
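
`binned_gather` packs tokens into a dense `[num_experts, expert_capacity, hidden]` layout and `binned_scatter` reverses it. A shape-level sketch under the same assumptions as above (CUDA device, hypothetical `megablocks` import path):

```python
# Shape-level sketch of the binned gather/scatter round trip
# (assumes the compiled kernels; `ops` re-exports these wrappers).
import torch
from megablocks import ops  # hypothetical import path

sl, hs, ne, top_k = 512, 256, 4, 1
ec = sl // ne                                   # expert capacity at factor 1.0
x = torch.randn(sl, hs, device="cuda", dtype=torch.float16)
top_expert = torch.randint(0, ne, (sl,), device="cuda", dtype=torch.int32)

bin_ids, indices = ops.sort(top_expert)
bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)

gathered = ops.binned_gather(x, indices, bins, ec, top_k)            # [ne, ec, hs]
restored = ops.binned_scatter(gathered, indices, None, bins, top_k)  # [sl, hs]
```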
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/binned_scatter.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/cumsum.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
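
Both wrappers share the semantics of a plain cumulative sum; only the implementation is a fused kernel. A CPU reference for what inclusive vs. exclusive means here:

```python
# Reference semantics of the two cumsum flavors (plain PyTorch equivalent;
# the ops above dispatch to a fused CUDA kernel with the same meaning).
import torch

counts = torch.tensor([3, 1, 4, 2], dtype=torch.int32)
inclusive = torch.cumsum(counts, dim=0)   # [3, 4, 8, 10]
exclusive = inclusive - counts            # [0, 3, 4, 8]
print(inclusive.tolist(), exclusive.tolist())
```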
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/gather.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/histogram.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1), which calls
+# torch.as_strided(...). Circumvent this chain to avoid the overhead
+# it adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+            '0::Fwd::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+            '0::GradX::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+            '0::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
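
`build_sparse_matrix` gives every row block the same number of nonzero blocks (`fhs // blocking`), so the offsets are just a strided arange. A toy illustration of that layout arithmetic (pure PyTorch; the column indices themselves come from `ops.topology`):

```python
# Toy illustration of the block-sparse offsets built in build_sparse_matrix.
import torch

blocking = 128
padded_tokens, fhs = 512, 256           # assumed toy sizes, both multiples of 128
block_rows = padded_tokens // blocking  # 4 row blocks
blocks_per_row = fhs // blocking        # 2 nonzero blocks per row block

offsets = torch.arange(0, block_rows * blocks_per_row + 1, blocks_per_row,
                       dtype=torch.int32)
print(offsets.tolist())                 # [0, 2, 4, 6, 8]
```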
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/padded_gather.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/padded_scatter.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
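
`PaddedScatterOp` saves `x` for backward only when the weights gradient is actually needed. A minimal, CPU-only sketch of that save-on-demand pattern with a generic `autograd.Function` (illustrative, not the megablocks op):

```python
# Minimal sketch of the "save x only if the weights gradient is needed"
# pattern used by PaddedScatterOp, shown on a trivial elementwise op.
import torch

class WeightedCopy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weights):
        maybe_x = [x] if ctx.needs_input_grad[1] else []
        ctx.save_for_backward(weights, *maybe_x)
        return x * weights

    @staticmethod
    def backward(ctx, grad):
        weights = ctx.saved_tensors[0]
        dx = grad * weights if ctx.needs_input_grad[0] else None
        dw = grad * ctx.saved_tensors[-1] if ctx.needs_input_grad[1] else None
        return dx, dw

x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
WeightedCopy.apply(x, w).sum().backward()
```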
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+        def benchmark():
+            # top_k == 1: each token is assigned to a single expert.
+            return ops.binned_gather(x, indices, bins, ec, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.binned_gather(x, indices, bins, ec, 1)
+
+        def benchmark():
+            # No scatter weights; top_k == 1 as in the gather above.
+            return ops.binned_scatter(x, indices, None, bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+        def benchmark():
+            # top_k == 1: each token is assigned to a single expert.
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+        def benchmark():
+            # No scatter weights; top_k == 1 as in the gather above.
+            return ops.padded_scatter(x, indices, bin_ids, None, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/repeat.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/replicate.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/round_up.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
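
A small worked example of the round-up arithmetic above (plain PyTorch, no custom kernel needed):

```python
# Worked example: round each count up to the next multiple of `value`.
import torch

x = torch.tensor([1, 128, 129, 300], dtype=torch.int32)
value = 128
rounded = torch.div(x + (value - 1), value, rounding_mode='trunc') * value
print(rounded.tolist())  # [128, 128, 256, 384]
```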
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/scatter.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/sort.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
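
When the values are bounded, passing `end_bit` limits the radix sort to the bits that matter, for example expert ids below `num_experts`. A usage sketch, assuming a CUDA build of the extension and a hypothetical `megablocks` import path:

```python
# Usage sketch for the radix sort (assumes the compiled CUDA extension).
import torch
from megablocks import ops  # hypothetical import path

ne = 8                                          # expert ids live in [0, 8)
top_expert = torch.randint(0, ne, (1024,), device="cuda", dtype=torch.int32)

# Only 3 bits are needed to order values below 8; passing end_bit avoids
# sorting on all 32 bits of the dtype.
bin_ids, indices = ops.sort(top_expert, 3)
```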
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/stk_autocast.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):  # e.g. kwargs; cast keys and values recursively
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
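
`custom_fwd` disables autocast inside the op but first casts eligible tensors to the ambient autocast dtype; only CUDA floating-point tensors other than float64 qualify. A small sketch of that eligibility rule:

```python
# Sketch of the eligibility rule used by _cast/custom_fwd above.
import torch

def is_eligible(x: torch.Tensor) -> bool:
    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)

print(is_eligible(torch.zeros(2, dtype=torch.float32)))   # False: CPU tensor
if torch.cuda.is_available():
    print(is_eligible(torch.zeros(2, device="cuda")))                          # True
    print(is_eligible(torch.zeros(2, device="cuda", dtype=torch.float64)))     # False
```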
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/sum.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/ops/topology.py b/build/torch28-cxx11-cu128-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with a better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
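
The topology op emits one int16 column index per nonzero block, i.e. `output_block_rows * output_block_columns` entries laid out row-major. A shape-level sketch under the same assumptions as the other usage sketches (CUDA device, hypothetical `megablocks` import path):

```python
# Shape-level sketch of the topology output (assumes the compiled extension).
import torch
from megablocks import ops  # hypothetical import path

blocking, ne, fhs = 128, 4, 256
padded_tokens_per_expert = torch.full((ne,), 256, device="cuda", dtype=torch.int32)
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)

block_rows = int(padded_tokens_per_expert.sum()) // blocking   # 8 row blocks
blocks_per_row = fhs // blocking                               # 2 blocks per row
column_indices = ops.topology(padded_bins, blocking, block_rows, blocks_per_row)
assert column_indices.numel() == block_rows * blocks_per_row
```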
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/autocast.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):  # e.g. kwargs; cast keys and values recursively
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
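+
+
+# Usage sketch (illustrative only): row_indices expands BCSR row offsets into
+# one block-row id per nonzero block. For offsets [0, 2, 3] (two nonzero
+# blocks in block-row 0, one in block-row 1) the kernel writes out = [0, 0, 1].
+# The shape, data and column_indices arguments are unused and may be None:
+#
+#   offsets = torch.tensor([0, 2, 3], dtype=torch.int32, device="cuda")
+#   out = torch.empty(3, dtype=torch.int16, device="cuda")
+#   row_indices(None, None, offsets, None, out)   # out -> [0, 0, 1]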
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/matrix.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
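+# Illustrative trace of _transpose: for row_indices [0, 0, 1] and
+# column_indices [0, 1, 0] on a 2x2 block grid, argsort of the column indices
+# is [0, 2, 1], so column_indices_t = [0, 1, 0], block_offsets_t = [0, 2, 1]
+# and offsets_t = [0, 2, 3] (two blocks in transposed block-row 0, one in
+# transposed block-row 1).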
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+ to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
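+
+
+# Usage sketch (illustrative; assumes b was built from a's metadata so the two
+# operands share a topology):
+#
+#   b = Matrix(a.size(), torch.randn_like(a.data),
+#              a.row_indices, a.column_indices, a.offsets)
+#   c = mul(a, b)   # c.data == a.data * b.data, same sparsity pattern as a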
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
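+
+
+# Usage sketch (illustrative; uses the stk.random helpers and shapes that are
+# multiples of the 128x128 blocking assumed by the Triton backend):
+#
+#   w = stk.random.randn((256, 512), sparsity=0.5, blocking=128).cuda()
+#   x = torch.randn(512, 384, device="cuda", dtype=torch.float16)
+#   d = torch.randn(384, 256, device="cuda", dtype=torch.float16)
+#   y = dsd(w, x)          # sparse [256, 512] @ dense [512, 384] -> dense [256, 384]
+#   z = dds(d, w)          # dense [384, 256] @ sparse [256, 512] -> dense [384, 512]
+#   s = sdd(y, x.t(), w)   # dense @ dense -> sparse Matrix with w's topology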
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
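+# Illustrative trace of _expand_for_blocking: with blocking=2, the single
+# block coordinate [[1, 2]] expands to the four element coordinates of that
+# block, [[2, 4], [2, 5], [3, 4], [3, 5]] (block-row 1 covers rows 2-3,
+# block-column 2 covers columns 4-5).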
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
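+
+
+# Usage sketch (illustrative, CPU is sufficient):
+#
+#   m = stk.random.dense_mask(256, 256, 0.5, blocking=128)
+#   x = (torch.randn(256, 256) * m).half()
+#   sp = to_sparse(x, blocking=128)   # Matrix in BCSR form
+#   dn = to_dense(sp)                 # torch.equal(dn, x) holds
+#   total = sum(sp)                   # same value as dn.sum()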
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/random/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/random/random_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
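+
+
+# Usage sketch (illustrative):
+#
+#   m = dense_mask(256, 512, sparsity=0.75, blocking=128)   # dense 0/1 float tensor
+#   s = mask(256, 512, sparsity=0.75, blocking=128)         # sparse Matrix of ones
+#   r = randn((4, 64, 512), sparsity=0.75, blocking=128)    # random sparse Matrix;
+#                                                           #   batch dims flattened into rows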
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py b/build/torch28-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from stk import random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/xpu_fused_moe.py b/build/torch28-cxx11-cu128-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2e7c6692f101f9141e9d716c8af6ac92be95351
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,577 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops
+
+
+# Install meta kernels for torch.compile compatibility
+def _install_xpu_meta_kernels():
+ """Install meta kernels for XPU MoE operations to support torch.compile"""
+
+ # Patch cutlass_grouped_gemm_interface
+ if hasattr(ops, "cutlass_grouped_gemm_interface"):
+ original_gemm = ops.cutlass_grouped_gemm_interface
+
+ def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
+ expert_first_token_offset, N, K, num_experts,
+ is_B_int4, is_B_mxfp4):
+ if torch.compiler.is_compiling():
+ # Meta implementation - ptr_D is the output, return it
+ return ptr_D
+ return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
+ expert_first_token_offset, N, K, num_experts,
+ is_B_int4, is_B_mxfp4)
+
+ ops.cutlass_grouped_gemm_interface = gemm_with_meta
+
+ # Patch fused_moe_prologue
+ if hasattr(ops, "fused_moe_prologue"):
+ original_prologue = ops.fused_moe_prologue
+
+ def prologue_with_meta(input, token_selected_experts, token_final_scales,
+ workspace, hidden_size, inter_size, num_experts_on_rank):
+ if torch.compiler.is_compiling():
+ # Meta implementation - this op modifies workspace in-place
+ return None
+ return original_prologue(input, token_selected_experts, token_final_scales,
+ workspace, hidden_size, inter_size, num_experts_on_rank)
+
+ ops.fused_moe_prologue = prologue_with_meta
+
+ # Patch moe_gather
+ if hasattr(ops, "moe_gather"):
+ original_gather = ops.moe_gather
+
+ def gather_with_meta(output, moe_output, topk_weights,
+ unpermuted_row_to_permuted_row, num_experts):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output is modified in-place
+ return None
+ return original_gather(output, moe_output, topk_weights,
+ unpermuted_row_to_permuted_row, num_experts)
+
+ ops.moe_gather = gather_with_meta
+
+ # Patch activation ops
+ for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
+ "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
+ "swigluoai_and_mul"]:
+ if hasattr(ops, act_name):
+ original_act = getattr(ops, act_name)
+
+ def make_act_wrapper(orig_fn):
+ def act_with_meta(*args, **kwargs):
+ if torch.compiler.is_compiling():
+ # Meta implementation - in-place ops, return None
+ return None
+ return orig_fn(*args, **kwargs)
+ return act_with_meta
+
+ setattr(ops, act_name, make_act_wrapper(original_act))
+
+
+# Install meta kernels on module load
+_install_xpu_meta_kernels()
+
+
+# Default grouped GEMM wrapper (cutlass_grouped_gemm_xe2 below is the Xe2-specific variant).
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
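+# Illustrative example: with num_tokens=1000 and num_experts_per_node=8, the
+# candidates 32 and 64 are rejected (32 blocks * 8 > 32 and 16 blocks * 8 > 64)
+# and 128 is returned (8 blocks * 8 <= 128).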
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
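+# Illustrative trace of implement_zp for one packed byte: 0x2B holds the
+# unsigned nibbles 2 (high) and 11 (low). Subtracting the default zero point
+# of 8 gives -6 and 3, which pack_compact re-encodes as sign-plus-low-3-bits
+# nibbles 0b1010 and 0b0011, so 0x2B maps to 0xA3.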
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+ # TODO: this will all be integrated into the C++ op; temporarily exposed here until the GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size].view(torch.int64)
+ unpermuted_row_to_permuted_row = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size].view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ unpermuted_row_to_permuted_row,
+ num_experts_per_node)
+ return output
+
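+# Usage sketch (illustrative only; hypothetical sizes, unquantized bf16
+# weights laid out as in the docstring above):
+#
+#   E, H, I, T, K = 8, 1024, 4096, 16, 2
+#   hidden = torch.randn(T, H, device="xpu", dtype=torch.bfloat16)
+#   w13 = torch.randn(E, 2 * I, H, device="xpu", dtype=torch.bfloat16)
+#   w2 = torch.randn(E, H, I, device="xpu", dtype=torch.bfloat16)
+#   weights, ids = torch.topk(torch.rand(T, E, device="xpu"), K, dim=-1)
+#   out = xpu_fused_moe(hidden, w13, None, None, w2, None, None,
+#                       weights.float(), ids, K, "silu", E)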
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=moe_num_experts,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
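+# Illustrative sketch of the convenience wrappers above (not executed at import
+# time; assumes a CUDA device and integer inputs, which is what the routing code
+# feeds these kernels):
+#
+#   x = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")
+#   cumsum(x, dim=0)                  # inclusive -> [3, 4, 6]
+#   cumsum(x, dim=0, exclusive=True)  # exclusive -> [0, 3, 4]
+#   values, order = argsort(x)        # -> [1, 2, 3], [1, 2, 0]
+#   histogram(x, num_bins=4)          # counts of values 0..3 -> [0, 1, 1, 1]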
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/activation_fn.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/all_to_all.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
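+
+# Illustrative sketch (hypothetical sizes): with two expert-parallel ranks, if this
+# rank holds 5 token rows of which 3 go to rank 0 and 2 go to rank 1, then
+# input_split_sizes = [3, 2]; if the peers will send it 4 and 6 rows respectively,
+# output_split_sizes = [4, 6] and `out` is allocated with 10 rows. With
+# async_op=True the returned handle must be waited on before `out` is read.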
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/arguments.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+                import triton
+                from packaging import version as _version
+                if _version.parse(triton.__version__) >= _version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
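+
+# Illustrative sketch (values are hypothetical): a bf16 dropless-MoE configuration
+# using the grouped GEMM backend. `fp16` defaults to True, so it is disabled
+# explicitly when requesting bf16; `device` defaults to the current CUDA device.
+#
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_type='glu',
+#       mlp_impl='grouped',
+#       fp16=False,
+#       bf16=True,
+#   )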
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/common.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/dmlp_registry.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e. only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
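+
+# Illustrative sketch: with Arguments(mlp_type='glu', mlp_impl='grouped') the call
+# `dmlp_registry.get(args)` returns a GroupedGLU instance, while the defaults
+# ('mlp', 'sparse') return a SparseMLP (subject to the triton < 3.2.0 check in
+# Arguments.__post_init__). Unknown types or backends raise ValueError.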
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/dmoe.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
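+
+    # Illustrative sketch (hypothetical counts, blocking=128): with two experts
+    # receiving 3 and 5 tokens, tokens_per_expert = [3, 5] and bins = [3, 8],
+    # while padded_tokens_per_expert = [128, 128] gives padded_bins = [128, 256].
+    # padded_gather then places each expert's tokens at the start of its padded
+    # 128-row slot so the block-sparse topology can assume full blocks.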
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/gelu.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
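+
+# Illustrative check (sketch, not executed here): the in-place backward above is
+# the analytic derivative of the tanh-approximated GeLU, so for a dense tensor it
+# should agree with autograd, e.g.
+#
+#   x = torch.randn(16, dtype=torch.float64, requires_grad=True)
+#   F.gelu(x, approximate='tanh').sum().backward()
+#   ref = _gelu_backward_inplace(torch.ones_like(x), x.detach())
+#   assert torch.allclose(x.grad, ref, atol=1e-6)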
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/glu.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Apply the activation function and the GLU gate.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
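+# Shape sketch for GroupedGLU.forward (names are illustrative): x arrives already
+# permuted per expert as [total_tokens, hidden_size], batch_sizes sums to
+# total_tokens, and w1/v1/w2 are viewed as [ne, ffn_hidden_per_rank, hidden_size].
+# Both trans_b grouped GEMMs produce [total_tokens, ffn_hidden_per_rank]; the
+# elementwise product applies the GLU gate, and the final grouped GEMM maps back
+# to [total_tokens, hidden_size].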
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/memory_test.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/mlp.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
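+# Illustrative sketch (assuming pure expert parallelism, i.e. hidden sharding
+# degree 1): with moe_num_experts=8 and an expert-parallel world size of 4, each
+# rank slices out 2 experts' full [ffn_hidden_size, hidden_size] weights, and
+# because every rank initializes the same master tensor first, the slices are
+# identical across data-parallel replicas for a given seed.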
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+    # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/moe.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert entries '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
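+# Worked example (hypothetical, single pipeline stage so num_layers layers are
+# accumulated here): with moe_loss_weight=0.1, 8 experts, top_k=2, 2 layers and
+# 1024 tokens, scale = (8 * 0.1) / (2 * 1024 * 2) ~= 1.95e-4. Under perfect
+# balance each expert sees 256 assignments per layer with mean router score 1/8,
+# so the dot product is 2 * 8 * 256 * 0.125 = 512 and the loss equals the
+# configured moe_loss_weight, 0.1.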
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
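+
+    # Illustrative numbers (hypothetical): with 4096 tokens on this rank, top_k=2,
+    # a single expert-parallel rank, 8 experts and moe_capacity_factor=1, each
+    # expert's capacity is int(1 * 2 * 4096 * 1 / 8) = 1024 token slots.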
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
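+    # Minimal sketch of the loss above with made-up routing results: 4 tokens,
+    # top_k = 1 and 2 experts, perfectly balanced routing.
+    #
+    #   tokens_per_expert = torch.tensor([2.0, 2.0])
+    #   expert_scores = torch.full((4, 2), 0.5)
+    #   scale = 2 / (4 * 1)  # num_experts / (tokens * top_k)
+    #   loss = scale * torch.dot(tokens_per_expert, expert_scores.mean(dim=0))
+    #   # loss == 1.0, the value this loss takes for perfectly balanced routing.
+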
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
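+    # Worked example (hypothetical inputs) of what indices_and_bins returns.
+    # With top_expert = [1, 0, 1, 0] and num_experts = 2:
+    #   bin_ids           = [0, 0, 1, 1]  # expert ids in sorted order
+    #   indices           = [1, 3, 0, 2]  # permutation that produces that order
+    #   tokens_per_expert = [2, 2]        # histogram of assignments
+    #   bins              = [2, 4]        # inclusive cumsum; expert e owns rows
+    #                                     # bins[e-1]:bins[e] after the gather
+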
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+            # If we're sharding the experts along the hidden dimension,
+            # multiple devices own parts of the same set of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+        # If we're sharding the experts along the hidden dimension,
+        # multiple devices own parts of the same set of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+            # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+        # Reduce along the hidden sharding dimension to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+        # Un-permute locally to set up for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+        # NOTE: If we're going to cast the activations to a lower precision,
+        # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/mpu.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
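+
+
+# Worked example (assumed configuration, for illustration only): with
+# world_size = 8, moe_num_experts = 4 and ffn_hidden_size = 4096,
+#   expert_sharding_degree -> min(8, 4) = 4
+#   hidden_sharding_degree -> 8 // 4    = 2
+#   experts_per_rank       -> 4 // 4    = 1
+#   features_per_rank      -> 4096 // 2 = 2048
+# i.e. each rank owns one expert and half of that expert's ffn features.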
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/router.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
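+# Minimal sketch of the unscaled z-loss for a single router (made-up logits):
+#
+#   logits = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
+#   zloss = torch.logsumexp(logits, dim=1).square().mean()
+#
+# The loss penalizes large router logits, keeping the softmax well conditioned;
+# it is scaled by args.moe_zloss_weight before being returned above.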
+
+# NOTE: To enable end-to-end benchmarking without convergence, we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
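+    # Shape sketch (illustrative): for scores of shape [num_tokens, num_experts],
+    # both branches return a (values, indices) pair of shape
+    # [num_tokens, moe_top_k]; the max() branch is just a cheaper special case
+    # for moe_top_k == 1.
+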
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch28-cxx11-cu129-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_megablocks_4f35d2a.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_megablocks_4f35d2a.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..778d495436e824f26c32d2932a721e98d5c38807
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_megablocks_4f35d2a.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5738502578db7ee323b7c4fbf17f40cac29061c825f88ce0e9d331ec0f3e7f06
+size 16003376
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..69479df3044627d4a8ac3fb70d0b1f0e9b22deed
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_4f35d2a
+ops = torch.ops._megablocks_4f35d2a
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_4f35d2a::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_version.py b/build/torch28-cxx11-cu129-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/backend/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/backend/kernels.py b/build/torch28-cxx11-cu129-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have
+# CUDA. This approach preserves the original code while enabling testing
+# without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+    # NOTE: There is no padding, so the number of output rows equals the
+    # number of input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
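+# Round-trip sketch (hypothetical call, shapes only): with top_k = 1 and no
+# weights, `scatter` undoes the permutation applied by `gather`:
+#
+#   y = gather(x, indices, bin_ids, None, bins, top_k=1)
+#   # ... expert computation on the permuted rows of y ...
+#   out = scatter(y, indices, bin_ids, None, bins, top_k=1)  # rows align with x
+#
+# With top_k > 1 each token appears top_k times in the permuted tensor and
+# scatter reduces those copies back to one row per token, scaled by `weights`.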
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per entry of 'wgrad'. Array 'x' has a greater or equal
+    # number of rows than 'grad' since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/benchmark_util.py b/build/torch28-cxx11-cu129-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
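+
+
+# Example usage (illustrative only; assumes a CUDA device is available):
+#
+#   a = torch.randn(4096, 4096, device='cuda')
+#   b = torch.randn(4096, 4096, device='cuda')
+#   mean_ms, std_ms = benchmark_function(lambda: torch.matmul(a, b))
+#   log_benchmark('matmul', {'m': 4096, 'n': 4096, 'k': 4096}, mean_ms, std_ms)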
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm/backend.py b/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package because
+# grouped_gemm is vendored into megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm/ops.py b/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
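+
+
+# Example usage (illustrative shapes; requires the compiled grouped GEMM
+# backend and a CUDA device):
+#
+#   num_experts, hidden, ffn = 4, 1024, 4096
+#   batch_sizes = torch.tensor([3, 1, 0, 4])  # tokens per expert, on the CPU
+#   a = torch.randn(int(batch_sizes.sum()), hidden, device='cuda', dtype=torch.bfloat16)
+#   b = torch.randn(num_experts, hidden, ffn, device='cuda', dtype=torch.bfloat16)
+#   c = gmm(a, b, batch_sizes)  # c has shape [8, ffn]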
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm_util.py b/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm
+ pass
+ _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.',
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/layers.py b/build/torch28-cxx11-cu129-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a66957b08d748fd5b4fca8ad5f2c68c81cf429
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/layers.py
@@ -0,0 +1,1230 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
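+# Illustrative call (made-up sizes; the router weight/bias would normally come
+# from a trained model):
+#
+#   x = torch.randn(2, 16, 1024)  # [batch, seq, hidden]
+#   w = torch.randn(8, 1024)      # [num_experts, hidden]
+#   logits, weights, indices = route_tokens(
+#       x, w, router_bias=None, moe_top_k=2, moe_num_experts=8)
+#   # logits: [32, 8]; weights and indices: [32, 2], one row per flattened token.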
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
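+# Shape sketch for mlp_forward (illustrative; the shapes are inferred from the
+# bmm calls above, not a documented contract): x is expected as
+# [num_experts, capacity, hidden] so each bmm runs one GEMM per expert,
+# w1 as [num_experts, hidden, 2 * ffn] with gate and up columns interleaved
+# (hence the ::2 / 1::2 split), and w2 as [num_experts, ffn, hidden].
+# gate * sigmoid(alpha * gate) with alpha ~= 1.702 is the sigmoid
+# approximation of GELU, clamped by `limit` for numerical stability.
+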
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ for x in expert_scores
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
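+
+# Worked example of the scale above (illustrative numbers only): moe_loss_weight=0.1,
+# moe_num_experts=8, num_layers=4, tokens=1024 and moe_top_k=2 give
+# scale = (8 * 0.1) / (4 * 1024 * 2) ~= 9.77e-5, applied to the dot product of the
+# concatenated per-expert token counts and mean expert scores.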
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
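+
+# Worked example: tokens=4096, top_k=2, num_experts=64, world_size=1 and
+# moe_capacity_factor=1.25 give tokens_per_expert = 2 * 4096 / 64 = 128 and an
+# expert capacity of int(1.25 * 128) = 160.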
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+ tokens, score_num_experts = expert_scores.size()
+ assert score_num_experts == num_experts
+ assert len(tokens_per_expert.size()) == 1
+ (count_num_experts,) = tokens_per_expert.size()
+ assert count_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
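+
+# Minimal sketch on toy inputs (hypothetical values; 16 tokens, 4 experts, top_k=2,
+# so the counts sum to 32 routed assignments):
+#   scores = torch.rand(16, 4).softmax(dim=-1)
+#   counts = torch.tensor([8, 8, 8, 8])
+#   loss = load_balancing_loss(counts, scores, top_k=2, num_experts=4)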
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
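+
+# Usage sketch (requires CUDA, since ops.sort/histogram/inclusive_cumsum are GPU
+# kernels); sort_end_bit should cover log2(num_experts):
+#   top_expert = torch.randint(0, 4, (32,), device="cuda")
+#   indices, bin_ids, bins, tokens_per_expert = indices_and_bins(top_expert, 2, 4)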
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
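+
+# permute_and_compute is the single-device inner loop: binned_gather groups the
+# flattened tokens into a [num_experts, expert_capacity, hidden] buffer, mlp_forward
+# runs all experts as batched matmuls, and binned_scatter applies the expert weights
+# while writing each token back to its original position.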
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
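+
+# Shape contract (illustrative): x is [sl, bs, hs] while expert_weights and
+# top_experts are [sl * bs, top_k]; the returned tensor is [sl * bs, hs] and is
+# reshaped back to the input shape by the caller (see moe_forward below).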
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Launched asynchronously so the exchange overlaps with the local permutation below
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: Optional[float] = None,
+ moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save the load balancing loss if needed. Note that the loss weight is
+ # hard-coded to 0.0 here, so nothing is recorded unless this is made configurable.
+ moe_loss_weight = 0.0
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: Optional[float] = None,
+ moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
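+
+# Usage sketch (hypothetical sizes; any in-place torch init callable works):
+#   up_w, down_w, up_b, down_b = create_shared_expert_weights(
+#       hidden_size=1152,
+#       shared_expert_hidden_size=3072,
+#       device=torch.device("cuda"),
+#       dtype=torch.bfloat16,
+#       init_method=torch.nn.init.xavier_uniform_,
+#   )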
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
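+# Sketch of wiring up the shared expert (assumes the router/experts attributes are
+# configured the same way as for the base MegaBlocksMoeMLP):
+#   mlp = MegaBlocksMoeMLPWithSharedExpert()
+#   up_w, down_w, _, _ = create_shared_expert_weights(
+#       hidden_size=1152, shared_expert_hidden_size=3072,
+#       device=torch.device("cuda"), dtype=torch.float32,
+#       init_method=torch.nn.init.kaiming_uniform_,
+#   )
+#   mlp.set_shared_expert_weights(up_w, down_w, weighted_sum=True)
+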
+
+# Patch for XPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/megablocks/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is: once added to `sys.modules` it would
+ # also be picked up by other imports. Instead, derive a unique module name
+ # from the hex-encoded hash of the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2-byte (fp16) elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/binned_gather.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/binned_scatter.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/cumsum.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/gather.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/histogram.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/histogram_benchmark.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/matmul_benchmark.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd:DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradX:DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradW:DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/padded_gather.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/padded_scatter.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/permute_benchmark.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+ return ops.binned_gather(x, indices, bins, ec, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.binned_gather(x, indices, bins, ec, 1)
+
+ def benchmark():
+ return ops.binned_scatter(x, indices, None, bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+ return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ def benchmark():
+ return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/repeat.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all(t == 1 for t in tiling):
+ return x
+ return x.repeat(*tiling)
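+
+
+# Illustrative sketch only (not part of the upstream API): `repeat` is a thin
+# wrapper around Tensor.repeat that skips the copy when the tiling is all ones.
+def _repeat_example():
+ x = torch.ones(2, 3)
+ assert repeat(x, torch.Size((1, 1))) is x # all-ones tiling returns the input unchanged
+ assert repeat(x, torch.Size((2, 1))).shape == (4, 3) # tiled twice along dim 0
+ return x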
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/replicate.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/round_up.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+ # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
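+
+
+# Illustrative sketch only (not part of the upstream API): round_up pads each
+# entry of an int32 tensor up to the next multiple of `value`, e.g. per-expert
+# token counts padded to a block size of 128.
+def _round_up_example():
+ counts = torch.tensor([1, 128, 129, 300], dtype=torch.int32)
+ return round_up(counts, 128) # tensor([128, 128, 256, 384], dtype=torch.int32)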
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/scatter.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
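+# NOTE: In the MoE layers this reverses `gather`: expert outputs arranged by
+# bin are written back to their original token positions, scaled by the
+# routing weights and summed over the top_k copies of each token.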
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/sort.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
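+
+
+# Illustrative sketch only; assumes the compiled megablocks_ops extension and a
+# CUDA device are available. `sort` returns the sorted keys together with the
+# permutation that produced them, which is how the routing code recovers the
+# per-expert token order.
+def _sort_example():
+ top_expert = torch.randint(0, 8, (16,), dtype=torch.int32, device="cuda")
+ bin_ids, indices = sort(top_expert, 3) # 3 bits suffice for values < 8
+ return bin_ids, indices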
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/sort_benchmark.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit))
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/stk_autocast.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+ elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/sum.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
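+
+
+# Illustrative sketch only (not part of the upstream API): when the reduced
+# dimension has length one the reduction degenerates to a squeeze, which avoids
+# launching a reduction kernel.
+def _sum_example():
+ assert sum(torch.ones(1, 4), dim=0).shape == (4,) # squeeze path
+ assert sum(torch.ones(3, 4), dim=0).shape == (4,) # true reduction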
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/ops/topology.py b/build/torch28-cxx11-cu129-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/autocast.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+ elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
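+
+
+# Illustrative sketch only (not part of the upstream API): how the decorators
+# are meant to be used. The wrapped forward sees its floating point CUDA inputs
+# cast to the active autocast dtype, and the backward runs with autocast
+# disabled so gradients stay in the forward dtype.
+class _ScaleExample(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, x, scale):
+ ctx.scale = scale
+ return x * scale
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad):
+ return grad * ctx.scale, None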
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/sputnik.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/triton_kernels.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has a dimension of length {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ # Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/matrix.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+ to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
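+
+
+# Illustrative sketch only; assumes a CUDA device. `b` is built directly from
+# a's metadata so the shared-topology requirement holds by construction.
+def _mul_example():
+ import torch
+ from .matrix_ops import to_sparse
+ a = to_sparse(torch.randn(128, 128, device="cuda").half(), blocking=128)
+ b = Matrix(a.size(), torch.ones_like(a.data), a.row_indices, a.column_indices, a.offsets)
+ return mul(a, b) # same topology as a, entries equal to a.data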
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/linear_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
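+
+
+# Illustrative sketch only; assumes a CUDA device and the triton backend. The
+# names encode the storage formats of (output, lhs, rhs): dsd is dense = sparse
+# @ dense, dds is dense = dense @ sparse, and sdd produces a sparse output with
+# the sparsity pattern of `topo` from two dense operands.
+def _linear_ops_example():
+ from .matrix_ops import to_sparse
+ a_dense = torch.randn(256, 256, device="cuda").half()
+ b_dense = torch.randn(256, 256, device="cuda").half()
+ a_sparse = to_sparse(a_dense, blocking=128)
+ topo = to_sparse(torch.ones(256, 256, device="cuda").half(), blocking=128)
+ return dsd(a_sparse, b_dense), dds(a_dense, a_sparse), sdd(a_dense, b_dense, topo)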
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
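+# Illustrative round trip between the dense and sparse formats (values are
+# hypothetical; blocking=1 keeps every nonzero element in its own block):
+#   x = torch.tensor([[1., 0.], [0., 2.]], dtype=torch.float16)
+#   sp = to_sparse(x, blocking=1)   # two 1x1 nonzero blocks
+#   to_dense(sp)                    # reproduces x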
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/random/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/random/random_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
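+# Example (illustrative): dense_mask(8, 8, sparsity=0.75, blocking=4) keeps
+# round(4 * 0.25) = 1 of the four 4x4 blocks, i.e. 16 nonzero entries.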
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/stk/random/random_ops_test.py b/build/torch28-cxx11-cu129-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/xpu_fused_moe.py b/build/torch28-cxx11-cu129-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2e7c6692f101f9141e9d716c8af6ac92be95351
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,577 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops
+
+
+# Install meta kernels for torch.compile compatibility
+def _install_xpu_meta_kernels():
+ """Install meta kernels for XPU MoE operations to support torch.compile"""
+
+ # Patch cutlass_grouped_gemm_interface
+ if hasattr(ops, "cutlass_grouped_gemm_interface"):
+ original_gemm = ops.cutlass_grouped_gemm_interface
+
+ def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
+ expert_first_token_offset, N, K, num_experts,
+ is_B_int4, is_B_mxfp4):
+ if torch.compiler.is_compiling():
+ # Meta implementation - ptr_D is the output, return it
+ return ptr_D
+ return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
+ expert_first_token_offset, N, K, num_experts,
+ is_B_int4, is_B_mxfp4)
+
+ ops.cutlass_grouped_gemm_interface = gemm_with_meta
+
+ # Patch fused_moe_prologue
+ if hasattr(ops, "fused_moe_prologue"):
+ original_prologue = ops.fused_moe_prologue
+
+ def prologue_with_meta(input, token_selected_experts, token_final_scales,
+ workspace, hidden_size, inter_size, num_experts_on_rank):
+ if torch.compiler.is_compiling():
+ # Meta implementation - this op modifies workspace in-place
+ return None
+ return original_prologue(input, token_selected_experts, token_final_scales,
+ workspace, hidden_size, inter_size, num_experts_on_rank)
+
+ ops.fused_moe_prologue = prologue_with_meta
+
+ # Patch moe_gather
+ if hasattr(ops, "moe_gather"):
+ original_gather = ops.moe_gather
+
+ def gather_with_meta(output, moe_output, topk_weights,
+ unpermuted_row_to_permuted_row, num_experts):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output is modified in-place
+ return None
+ return original_gather(output, moe_output, topk_weights,
+ unpermuted_row_to_permuted_row, num_experts)
+
+ ops.moe_gather = gather_with_meta
+
+ # Patch activation ops
+ for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
+ "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
+ "swigluoai_and_mul"]:
+ if hasattr(ops, act_name):
+ original_act = getattr(ops, act_name)
+
+ def make_act_wrapper(orig_fn):
+ def act_with_meta(*args, **kwargs):
+ if torch.compiler.is_compiling():
+ # Meta implementation - in-place ops, return None
+ return None
+ return orig_fn(*args, **kwargs)
+ return act_with_meta
+
+ setattr(ops, act_name, make_act_wrapper(original_act))
+
+
+# Install meta kernels on module load
+_install_xpu_meta_kernels()
+
+
+# Default (non-quantized) grouped GEMM entry point.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
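+    # e.g. expert_token_count = [2, 3, 1] -> expert_offset = [0, 2, 5, 6]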
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
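+# Example (illustrative): with num_tokens=100 and num_experts_per_node=8 the
+# search stops at 32, since ceilDiv(100, 32) * 8 = 32 <= 32.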
+
+
+def implement_zp(qweight):
+    # Convert u4 to s4 so the GEMM kernel does not need to handle a zero point.
+    # Only the default zero point is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
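+# Example (illustrative): the packed byte 0xF0 holds the u4 nibbles 15 and 0.
+# Subtracting the default zero point of 8 gives the s4 values 7 and -8, which
+# are re-packed in two's complement as 0b0111_1000 = 0x78.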
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
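+    # Shape example (illustrative; values are hypothetical): with num_rows=16,
+    # hidden_size=1024, inter_size=4096, num_experts=8 and topk=2,
+    # hidden_states is [16, 1024], w13 is [8, 8192, 1024] (or the transposed
+    # [8, 1024, 8192] layout handled below), w2 is [8, 1024, 4096],
+    # topk_weights/topk_ids are [16, 2] and the output is [16, 1024].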
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: all of this will be integrated into the C++ function. Temporarily exposed here until the GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size].view(torch.int64)
+ unpermuted_row_to_permuted_row = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size].view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ unpermuted_row_to_permuted_row,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
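+# Shape example (illustrative): for x of shape [tokens, hidden] = [32, 1024],
+# router_weight of shape [num_experts, hidden] = [64, 1024] and moe_top_k=2,
+# logits is [32, 64] while expert_weights and expert_indices are both [32, 2];
+# the weights are softmaxed over the selected top-k experts only.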
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=moe_num_experts,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/__init__.py b/build/torch29-cxx11-cpu-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
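+# Illustrative usage of the convenience wrappers above (assumes the underlying
+# ops are available on the current device; values are hypothetical):
+#   x = torch.tensor([1, 2, 3], dtype=torch.int32)
+#   cumsum(x, exclusive=True)      # -> tensor([0, 1, 3])
+#   histogram(x, num_bins=4)       # per-bin counts for values 0..3
+#   values, order = argsort(x)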
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/__init__.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/activation_fn.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/all_to_all.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
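+# Illustrative usage (requires an initialized process group; sizes are
+# hypothetical): with two ranks each holding a [4, h] tensor,
+#   out, handle = all_to_all(x, output_split_sizes=[2, 2],
+#                            input_split_sizes=[2, 2], group=group)
+# sends one half of x to each rank and returns a [4, h] tensor assembled from
+# the peers' contributions.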
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/arguments.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (allows a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using a weighted sum for the shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
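+# Illustrative construction (values are hypothetical; 'grouped' additionally
+# requires the grouped_gemm backend to be available):
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_impl='grouped',
+#       fp16=False,
+#       bf16=True,
+#   )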
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/common.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/dmlp_registry.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e. only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
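+# Example (illustrative): get(Arguments(mlp_type='glu', mlp_impl='grouped'))
+# returns a GroupedGLU built from the given arguments.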
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/dmoe.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
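+        # Example (illustrative): with blocking=128 and tokens_per_expert of
+        # [3, 1], padded_tokens_per_expert is [128, 128], padded_bins is
+        # [128, 256] and bins is [3, 4].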
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/gelu.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
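+# The magic numbers above come from the tanh GELU approximation:
+# 0.79788456 ~= sqrt(2 / pi) and 0.1070322243 ~= 3 * 0.044715 * sqrt(2 / pi).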
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/glu.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+        w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)
+        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/memory_test.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6))
+
+    # Calculate weight and gradient memory usage (2 bytes per bf16 element).
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6))
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/mlp.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+        # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+        # this rank. If the master weights are created first, the PyTorch
+        # caching allocator appears to use the same memory block for these
+        # and the slice, which causes large increases in our peak memory
+        # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+        # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+        # this rank. If the master weights are created first, the PyTorch
+        # caching allocator appears to use the same memory block for these
+        # and the slice, which causes large increases in our peak memory
+        # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+    """MLP for the shared expert.
+
+    Note: adapted (copied and modified) from the LLM-Foundry MPTMLP class.
+    """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+            # Use a weighted sum for the shared expert output, weighted by
+            # the number of experts used so the total weight sums to one.
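+            # Worked example (hypothetical): with moe_top_k=4, t_experts is 5,
+            # so the shared expert output is scaled by 1/5 and the routed
+            # expert output by 4/5.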
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/moe.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
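+
+
+# NOTE: The three helpers above form a simple global accumulator: MoE layers
+# append (tokens_per_expert, expert_scores) during their forward pass,
+# batched_load_balancing_loss() below consumes the recorded values, and
+# clear_load_balancing_loss() should be called once per step so the list does
+# not grow without bound.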
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
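+
+
+# Worked example (hypothetical numbers): with moe_num_experts=8,
+# moe_loss_weight=0.1, num_layers=2, tokens=1024 and moe_top_k=2, the scale is
+# (8 * 0.1) / (2 * 1024 * 2) = 0.8 / 4096, applied to the dot product of the
+# concatenated per-expert token counts and mean routing scores.
+#
+# Minimal usage sketch (illustrative only; `model`, `batch` and `task_loss` are
+# placeholders, and the model's MoE layers are assumed to record their routing
+# statistics internally):
+#
+#   out = model(batch)
+#   loss = task_loss(out) + batched_load_balancing_loss(args)
+#   loss.backward()
+#   clear_load_balancing_loss()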
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks, this is the module that should
+# be wrapped so that the weight all-gathers can be scheduled *before* the expert model
+# parallel all-to-all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
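+        # Worked example (hypothetical numbers): with top_k=2, tokens=4096, an
+        # expert-parallel world size of 4 and 64 experts, tokens_per_expert is
+        # 2 * 4096 * 4 / 64 = 512, so moe_capacity_factor=1.25 gives a capacity
+        # of 640.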
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
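+
+        # Worked example (hypothetical, assuming a stable sort): for
+        # top_expert = [2, 0, 1, 0] with 3 experts, bin_ids = [0, 0, 1, 2],
+        # indices = [1, 3, 2, 0], tokens_per_expert = [2, 1, 1] and
+        # bins = [2, 3, 4].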
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
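+        # At this point x has shape [num_experts, expert_capacity, hidden_size];
+        # binned_gather zero-pads each expert's slice up to expert_capacity rows.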
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/mpu.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
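+
+
+# Worked example (hypothetical numbers): with an expert-parallel world size of 4
+# and moe_num_experts=2, expert_sharding_degree=2 and hidden_sharding_degree=2,
+# so each rank owns a single expert and ffn_hidden_size // 2 of its features
+# (assuming ffn_hidden_size is even).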
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/router.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
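+
+
+# For reference, each entry of the returned tensor is
+#   moe_zloss_weight * mean_over_tokens(logsumexp(logits, dim=1) ** 2)
+# for one recorded router, which penalizes large router logit magnitudes.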
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
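+
+
+# Illustrative example: for an input with 6 routing decisions and num_experts=4,
+# the forced assignment is arange(6) % 4 = [0, 1, 2, 3, 0, 1], reshaped to the
+# input's shape.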
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch29-cxx11-cpu-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+    """Returns a SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
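+
+
+# Example (illustrative): Arguments with mlp_type='glu' yields a SharedGLU,
+# while mlp_type='mlp' yields a SharedMLP.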
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_megablocks_cpu_6e04dec.abi3.so b/build/torch29-cxx11-cpu-x86_64-linux/_megablocks_cpu_6e04dec.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..fc992d33633b7be30174d1b5dbbe46f6bb5aaea9
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_megablocks_cpu_6e04dec.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18348238274eb1b281afe628b09ca6a4a5b8267370aaed7bf34a2bd91c9b815b
+size 2201200
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_ops.py b/build/torch29-cxx11-cpu-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9010966e70976a4a5febea9802b714fa9a068af4
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cpu_6e04dec
+ops = torch.ops._megablocks_cpu_6e04dec
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cpu_6e04dec::{op_name}"
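+
+
+# Example: add_op_namespace_prefix("sort") returns "_megablocks_cpu_6e04dec::sort".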
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/_version.py b/build/torch29-cxx11-cpu-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/backend/__init__.py b/build/torch29-cxx11-cpu-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/backend/kernels.py b/build/torch29-cxx11-cpu-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This preserves the original code while enabling testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
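+    # (Illustrative: with TOP_K=2, index_a values 4 and 5 both map to row 2 of 'a'.)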
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per entry of 'wgrad' (tokens * top_k of them). The input
+    # 'x' has a greater or equal number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
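+# A shape sketch of the round trip above (hypothetical sizes; `sort`,
+# `histogram` and `inclusive_cumsum` refer to the companion ops used elsewhere
+# in this package to build `indices` and `bins`):
+#
+#   x = torch.randn(4, 8, device="cuda")                      # 4 tokens
+#   top_expert = torch.tensor([1, 0, 1, 0], device="cuda").int()
+#   bin_ids, indices = sort(top_expert, 1)                    # [0,0,1,1], [1,3,0,2]
+#   bins = inclusive_cumsum(histogram(top_expert, 2), 0)      # [2, 4]
+#   y = binned_gather(x, indices, None, bins, 4, 1)           # (2, 4, 8)
+#   x2 = binned_scatter(y, indices, None, bins, 1)            # (4, 8)
+
+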
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/benchmark_util.py b/build/torch29-cxx11-cpu-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+    print(f'mean time = {time:.3f}ms, std time = {std:.3f}ms')
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
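+
+
+# A minimal usage sketch (the workload below is hypothetical, not part of this
+# module): time a CUDA op and report it with the logger above.
+#
+#   a = torch.randn(4096, 4096, device="cuda")
+#   mean_ms, std_ms = benchmark_function(lambda: torch.matmul(a, a))
+#   log_benchmark("matmul", {"n": 4096}, mean_ms, std_ms)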
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/cpu_fused_moe.py b/build/torch29-cxx11-cpu-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
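+# A tiny numeric sketch of the formula above (not part of the module API):
+# with gate = up = 0 the output is (0 + 1) * 0 * sigmoid(0) = 0, while inputs
+# beyond the limit saturate, e.g. gate = up = 10 clamps both to 7, giving
+# (7 + 1) * 7 * sigmoid(7 * 1.702) ≈ 56.
+
+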
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
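+# Shape sketch (hypothetical sizes): for x of shape (tokens, hidden) = (4, 8),
+# a router weight of shape (num_experts, hidden) = (2, 8) and moe_top_k = 2,
+# `logits` is (4, 2) while `expert_weights` and `expert_indices` are both
+# (4, 2); the weights are softmax-normalized over the selected experts.
+
+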
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over the experts and processes all tokens routed
+    to each expert with batched matrix operations, which avoids per-token
+    Python overhead on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # For each expert, find the (token_idx, topk_pos) pairs routed to it and
+    # process those tokens as a single batch.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
+
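+# A minimal CPU sketch wiring the routing helper into the fused loop above
+# (all sizes below are hypothetical):
+#
+#   T, H, I, E, K = 4, 8, 16, 2, 2
+#   x = torch.randn(T, H)
+#   w1 = torch.randn(E, H, 2 * I) * 0.02      # interleaved gate/up layout
+#   w2 = torch.randn(E, I, H) * 0.02
+#   router_w = torch.randn(E, H)
+#   _, weights, ids = route_tokens_cpu(x, router_w, None, K, E)
+#   out = cpu_fused_moe(x, w1, w2, weights, ids)   # -> (4, 8)
+
+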
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/cpu_moe_cpp.py b/build/torch29-cxx11-cpu-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm/__init__.py b/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm/backend.py b/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm/ops.py b/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
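+
+
+# Shape sketch (hypothetical sizes): grouping three GEMMs of 2, 3 and 1 rows.
+#
+#   a = torch.randn(6, 16, device="cuda")
+#   b = torch.randn(3, 16, 32, device="cuda")
+#   batch_sizes = torch.tensor([2, 3, 1])
+#   c = gmm(a, b, batch_sizes)   # (6, 32): rows 0-1 use b[0], 2-4 b[1], 5 b[2]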
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm_util.py b/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored into this package, so the external import is
+    # not needed here:
+    # import grouped_gemm
+    _grouped_gemm_is_available = True
+except ImportError:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/layers.py b/build/torch29-cxx11-cpu-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
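+# Worked example (hypothetical sizes): with world_size = 8, moe_num_experts = 4
+# and ffn_hidden_size = 3072, expert_sharding_degree = min(8, 4) = 4,
+# hidden_sharding_degree = 8 // 4 = 2, experts_per_rank = 4 // 4 = 1 and
+# features_per_rank = 3072 // 2 = 1536.
+
+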
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
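+# For example, with moe_jitter_eps = 0.01 each activation is scaled by an
+# independent uniform factor drawn from [0.99, 1.01].
+
+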
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
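+# Shape sketch for the expert MLP above (hypothetical sizes): with x of shape
+# (num_experts, capacity, hidden) = (2, 4, 8), w1 of shape (2, 8, 2 * 16) and
+# w2 of shape (2, 16, 8), `gate_up` is (2, 4, 32), the de-interleaved `gate`
+# and `up` are each (2, 4, 16), and `next_states` is (2, 4, 8).
+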
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
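+# For example, with shared_expert_weighted_sum = True and moe_top_k = 3 the
+# shared expert contributes 1/4 of the output and the routed experts 3/4.
+
+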
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
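+# Worked example (hypothetical sizes): 1024 tokens, top_k = 4, 128 experts,
+# capacity factor 1.0 and no expert parallelism give
+# int(1.0 * 4 * 1024 * 1 / 128) = 32 slots per expert.
+
+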
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
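+# Note: if `expert_scores` are per-token routing probabilities, perfectly
+# uniform routing gives tokens_per_expert = tokens * top_k / num_experts and
+# expert_scores.mean(dim=0) = 1 / num_experts for every expert, so the loss
+# evaluates to exactly 1.0, the balanced-routing baseline.
+
+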
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
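+# Worked example (hypothetical routing, assuming a stable sort): for
+# top_expert = [1, 0, 1, 0] and num_experts = 2, sorting yields
+# bin_ids = [0, 0, 1, 1] and indices = [1, 3, 0, 2], the histogram gives
+# tokens_per_expert = [2, 2], and the inclusive cumulative sum gives
+# bins = [2, 4].
+
+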
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Start the exchange asynchronously so it overlaps with the local
+        # permutation below.
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: Optional[torch.Tensor] = None,
+    w2: Optional[torch.Tensor] = None,
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: Optional[int] = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
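+        # Number of bits needed to radix-sort the expert indices (at least one).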
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
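+# A minimal usage sketch (hypothetical sizes, bf16 weights assumed) showing how the
+# helpers above fit together; the router and experts are configured as for MegaBlocksMoeMLP:
+#
+#   mlp = MegaBlocksMoeMLPWithSharedExpert()
+#   up_w, down_w, up_b, down_b = create_shared_expert_weights(
+#       hidden_size=1152,
+#       shared_expert_hidden_size=3072,
+#       device=torch.device("cuda"),
+#       dtype=torch.bfloat16,
+#       init_method=torch.nn.init.kaiming_uniform_,
+#   )
+#   mlp.set_shared_expert_weights(up_w, down_w, up_b, down_b)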
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/megablocks/__init__.py b/build/torch29-cxx11-cpu-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
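+# Load the package __init__.py one directory above this file and re-export its
+# symbols, so this build-variant directory behaves like the top-level package.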
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/metadata.json b/build/torch29-cxx11-cpu-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a5381dd80836f863378b9f33a559815688de9287
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/metadata.json
@@ -0,0 +1,5 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/__init__.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
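+    # init_process_group returns None; the calls below then fall back to the default process group.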
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/binned_gather.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/binned_scatter.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/cumsum.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/gather.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/histogram.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/histogram_benchmark.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/matmul_benchmark.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
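+        # contiguous() materializes w as (ne, fhs, hs); the second transpose views it
+        # back as (ne, hs, fhs) so bmm sees a transposed ("NT") operand without a copy.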
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd:DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradX:DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradW:DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/padded_gather.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/padded_scatter.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/permute_benchmark.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.binned_gather(x, indices, bins, ec, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.binned_gather(x, indices, bins, ec, 1)
+
+ def benchmark():
+            return ops.binned_scatter(x, indices, None, bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ def benchmark():
+            return ops.padded_scatter(x, indices, bin_ids, None, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/repeat.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
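+    # Equivalent to x.repeat(*tiling), but skips the copy when every tiling factor is 1.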
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/replicate.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
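+        # Replicates each value of 'x' across its bin's span, producing 'num_outputs' columns per row.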
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/round_up.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
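+    # Example: round_up(tensor([5, 128, 129], dtype=int32), 128) -> tensor([128, 128, 256]).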
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/scatter.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/sort.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
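+        # ops.sort writes the sorted keys into x_out and the matching permutation (argsort) into iota_out.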
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/sort_benchmark.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/stk_autocast.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/sum.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/ops/topology.py b/build/torch29-cxx11-cpu-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
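+        # 'out' receives the column index of every nonzero block in the block-sparse topology.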
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/__init__.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/__init__.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/autocast.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
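+
+
+# Illustrative sketch (not part of the module's API): the decorators are meant
+# to wrap the forward/backward of a torch.autograd.Function so the forward
+# receives inputs cast to the active autocast dtype while autocast itself is
+# disabled inside the op.
+#
+#   class _ScaleByTwo(torch.autograd.Function):
+#       @staticmethod
+#       @custom_fwd
+#       def forward(ctx, x):
+#           return x * 2
+#
+#       @staticmethod
+#       @custom_bwd
+#       def backward(ctx, grad):
+#           return grad * 2
+#
+#   with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+#       y = _ScaleByTwo.apply(torch.randn(4, device="cuda"))  # computed in bf16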
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/sputnik.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/triton_kernels.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: tensor has a dimension of length {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
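+
+
+# Reference semantics of the kernel above (an illustrative sketch, not used by
+# the library): each CSR-style offsets entry expands into one row id per
+# nonzero block, e.g. offsets [0, 2, 3] -> row indices [0, 0, 1].
+#
+#   def row_indices_reference(offsets: torch.Tensor) -> torch.Tensor:
+#       counts = offsets[1:] - offsets[:-1]
+#       return torch.repeat_interleave(
+#           torch.arange(len(counts), device=offsets.device), counts)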
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/matrix.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# TODO:
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers.
+# 3. Make indentation consistent.
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+        raise ValueError(
+            "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
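+
+
+# Illustrative sketch of the BCSR layout this class stores (assumed values,
+# not part of the module): a 2x2 grid of 128x128 blocks with nonzeros at block
+# positions (0, 0) and (1, 1) is described by per-block data plus row/column
+# indices and CSR-style row offsets.
+#
+#   blocking = 128
+#   data = torch.randn(2, blocking, blocking, dtype=torch.float16)
+#   row_indices = torch.tensor([0, 1], dtype=torch.int16)
+#   column_indices = torch.tensor([0, 1], dtype=torch.int16)
+#   offsets = torch.tensor([0, 1, 2], dtype=torch.int32)
+#   m = Matrix((256, 256), data, row_indices, column_indices, offsets)
+#   m.validate()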
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/__init__.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
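+
+
+# Illustrative usage (a sketch; assumes torch and the stk package are
+# importable): build two sparse matrices that share a topology and multiply
+# them element-wise.
+#
+#   import torch
+#   import stk
+#   x = torch.randn(256, 256, dtype=torch.float16)
+#   a = stk.ops.to_sparse(x, blocking=128)
+#   b = stk.ops.ones_like(a)   # same topology as `a`
+#   c = mul(a, b)              # entries equal to a.data * 1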
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/linear_ops.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
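+
+
+# Illustrative sketch of the three products (assumes stk is importable and the
+# backend kernels are available on the current device):
+#
+#   import stk
+#   a_dense = torch.randn(256, 256, dtype=torch.float16, device="cuda")
+#   a = stk.ops.to_sparse(a_dense, blocking=128)   # sparse operand / topology
+#   b = torch.randn(256, 256, dtype=torch.float16, device="cuda")
+#
+#   y0 = dsd(a, b)       # dense = sparse @ dense
+#   y1 = dds(b, a)       # dense = dense @ sparse
+#   y2 = sdd(b, b, a)    # sparse output restricted to a's topology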
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
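+
+
+# Illustrative round trip (a sketch, mirroring the format-conversion test in
+# matrix_ops_test.py): convert a block-sparse dense matrix to the Matrix
+# format and back.
+#
+#   import stk
+#   mask = stk.random.dense_mask(256, 256, sparsity=0.5, blocking=128)
+#   x = (torch.randn(256, 256) * mask).type(torch.float16)
+#   sparse_x = to_sparse(x, blocking=128)
+#   dense_x = to_dense(sparse_x)
+#   assert torch.all(torch.eq(x, dense_x))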
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/random/__init__.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/random/random_ops.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
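+
+
+# Illustrative usage (a sketch): `dense_mask` draws a {0, 1} block mask with
+# the requested sparsity, while `mask` and `randn` return the same structure
+# as a sparse stk Matrix.
+#
+#   m = dense_mask(256, 512, sparsity=0.75, blocking=128)   # dense 0/1 tensor
+#   s = mask(256, 512, sparsity=0.75, blocking=128)         # sparse Matrix of ones
+#   r = randn((256, 512), sparsity=0.75, blocking=128)      # sparse, normal data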
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/stk/random/random_ops_test.py b/build/torch29-cxx11-cpu-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from stk import random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cpu-x86_64-linux/xpu_fused_moe.py b/build/torch29-cxx11-cpu-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch29-cxx11-cpu-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM path (cutlass_grouped_gemm_xe2 below is the Xe2-specific variant).
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
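+
+# Illustrative usage (a sketch, not called elsewhere in this module):
+# round-trip a float32 tensor through its raw bytes.
+#
+#   x = torch.arange(4, dtype=torch.float32)
+#   raw = x.view(torch.uint8)                        # 16 raw bytes
+#   y = _bytes_to_typed_tensor(raw, torch.float32)   # equal to x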
+
+
+def implement_zp(qweight):
+    # Convert u4 to s4 so the GEMM kernel does not need to apply a zero point.
+    # Only the default zero point is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+    '''
+    hidden_states: [num_rows, hidden_size]
+    w13: [num_experts, 2*inter_size, hidden_size]
+    w13_scales:
+        None for bf16/fp16
+        or [num_experts] for fp8
+        or [num_experts, 2*inter_size, hidden_size // group_size] for 4-bit weights
+    w13_bias: [num_experts, 2*inter_size] or None
+    w2: [num_experts, hidden_size, inter_size]
+    w2_scales:
+        None for bf16/fp16
+        or [num_experts] for fp8
+        or [num_experts, hidden_size, inter_size // group_size] for 4-bit weights
+    w2_bias: [num_experts, hidden_size] or None
+    topk_weights: [num_rows, topk]
+    topk_ids: [num_rows, topk]
+    n_experts_per_token: int
+    activation: str
+    num_experts: int (number of experts on this rank)
+    ep_rank: int, ep_size: int (expert-parallel rank and world size)
+    is_fp8: bool
+    is_int4: bool
+    is_mxfp4: bool
+    '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+    # 4-bit and mxfp4 weights keep the [E, N, K] layout;
+    # other dtypes are transposed to [E, K, N] if needed.
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ func; temporarily exposed here until GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
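+    # config_ws reserves a named region in one flat uint8 workspace buffer,
+    # padding each region to a 256-byte boundary and recording its (size, offset).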
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
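+    # Grouped GEMM1: [num_moe_inputs, hidden_size] x w13 -> [num_moe_inputs, 2*inter_size],
+    # i.e. the fused gate+up projection over all routed tokens.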
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
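+    # Grouped GEMM2: [num_moe_inputs, inter_size] x w2 -> [num_moe_inputs, hidden_size],
+    # i.e. the down projection back to the model dimension.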
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
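+    # Note: softmax is taken over the selected top-k logits only, not over all experts.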
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+    """Extract device_mesh from the experts child's (unused) forward pre-hook closure, for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
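+        # Assumes moe_num_experts is evenly divisible by ep_size.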
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/activation_fn.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/all_to_all.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/arguments.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (allows using a custom FC layer, e.g. te.Linear, for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/common.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/dmoe.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/gelu.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/glu.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/memory_test.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/mlp.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
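+        # e.g. (illustrative) batch_sizes = [3, 0, 5, ...] means the first 3
+        # rows of x are routed to expert 0, none to expert 1, the next 5 to
+        # expert 2, and so on; the grouped GEMM uses these as its group sizes.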
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+            # Use a weighted sum for the shared expert output,
+            # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
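+            # e.g. with moe_top_k = 4, t_experts = 5: the shared expert
+            # contributes 1/5 of the output and the routed experts 4/5.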
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/moe.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert entries '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} expert_scores entries '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
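+    # e.g. (illustrative) num_experts=8, moe_loss_weight=0.1, num_layers=12,
+    # tokens=2048, top_k=2 -> scale = 0.8 / 49152 ~= 1.6e-5.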
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
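+        # e.g. (illustrative) top_k=2, tokens=4096, world_size=8,
+        # num_experts=64 -> tokens_per_expert = 1024; with
+        # moe_capacity_factor=1.25 the capacity is 1280 tokens per expert.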
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
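+        # Note: with perfectly uniform routing (each expert receives
+        # tokens * top_k / num_experts tokens and a mean score of
+        # 1 / num_experts) this loss evaluates to 1.0.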
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
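+            # e.g. (illustrative) with 4 experts split 2 per rank and no
+            # hidden sharding, arange(4) % 2 gives local expert ids
+            # [0, 1, 0, 1]: one entry per (source device, local expert) pair
+            # prior to replication over the received token counts.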
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/mpu.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
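+
+
+# Example (illustrative): with an expert parallel world size of 8 and
+# moe_num_experts = 4, expert_sharding_degree is 4 and hidden_sharding_degree
+# is 2, so each rank owns one expert and ffn_hidden_size // 2 of its features.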
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/router.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
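+        # e.g. for an input with 6 routed slots and num_experts = 4 this
+        # yields assignments [0, 1, 2, 3, 0, 1], reshaped to x's shape below.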
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_layers/sharedexpert_registry.py b/build/torch29-cxx11-cu126-aarch64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_megablocks_cuda_6e04dec.abi3.so b/build/torch29-cxx11-cu126-aarch64-linux/_megablocks_cuda_6e04dec.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..20de7d814a05bdd4ce30e7a8742261aa9f3b5f22
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_megablocks_cuda_6e04dec.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:581f5d3cd17031f674e6da22c23430881408630004e4ece5a57f9c36583665b5
+size 15121720
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f202b8db3c3f3028303ab4308cf35f950e2c74
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
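+
+    e.g. add_op_namespace_prefix("sort") -> "_megablocks_cuda_6e04dec::sort"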
+ """
+ return f"_megablocks_cuda_6e04dec::{op_name}"
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_version.py b/build/torch29-cxx11-cu126-aarch64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/backend/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/backend/kernels.py b/build/torch29-cxx11-cu126-aarch64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have
+# CUDA. This preserves the original code while still enabling tests without a GPU.
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since its rows may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since its rows may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/benchmark_util.py b/build/torch29-cxx11-cu126-aarch64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
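+
+
+# Example usage (illustrative):
+#
+#   a = torch.randn(1024, 1024, device='cuda')
+#   mean_ms, std_ms = benchmark_function(lambda: torch.matmul(a, a))
+#   log_benchmark('MatMul', {'n': 1024}, mean_ms, std_ms)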
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/cpu_fused_moe.py b/build/torch29-cxx11-cu126-aarch64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
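+
+    Example (illustrative):
+        For gate = up = 1.0, glu = 1.0 * sigmoid(1.702 * 1.0) ~= 0.846,
+        so the output is (1.0 + 1.0) * 0.846 ~= 1.69.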
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
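+
+    Example (illustrative shapes):
+        For x of shape [B, S, H], router_weight of shape [E, H] and
+        moe_top_k = 2, logits is [B * S, E] while expert_weights and
+        expert_indices are both [B * S, 2].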
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over experts but batches all tokens routed to
+    each expert into a single matmul, which is far more efficient on CPU than
+    processing tokens one at a time.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Build expert mask: which tokens go to which expert
+ # expert_mask[expert_id] contains indices of (token_idx, topk_pos) pairs
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
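+# Illustrative sketch (random weights, shape check only; the tensors below are
+# local to this example, not part of the module):
+#   x  = torch.randn(8, 64)                      # 8 tokens, hidden_size=64
+#   w1 = torch.randn(4, 64, 2 * 32) * 0.02       # 4 experts, inter_size=32
+#   w2 = torch.randn(4, 32, 64) * 0.02
+#   _, tw, ti = route_tokens_cpu(x, torch.randn(4, 64), None, moe_top_k=2, moe_num_experts=4)
+#   out = cpu_fused_moe(x, w1, w2, tw, ti, activation="silu", is_interleaved=False)
+#   assert out.shape == (8, 64)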
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/cpu_moe_cpp.py b/build/torch29-cxx11-cu126-aarch64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
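+# Illustrative sketch (shapes follow the docstring above; whether a given dtype
+# or packing layout is accepted depends on the compiled _ops kernel):
+#   x  = torch.randn(8, 64, dtype=torch.bfloat16)           # [M, K]
+#   w1 = torch.randn(4, 2 * 32, 64, dtype=torch.bfloat16)   # [E, 2N, K]
+#   w2 = torch.randn(4, 64, 32, dtype=torch.bfloat16)       # [E, K, N]
+#   tw = torch.rand(8, 2); ti = torch.randint(0, 4, (8, 2), dtype=torch.int32)
+#   out = fused_moe_cpp(x, w1, w2, tw, ti)                   # expected shape [8, 64]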
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm/backend.py b/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
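+# Output shapes (sketch, derived from _allocate_output above): with trans_a=False
+# and trans_b=False, `a` is [tokens, k], `b` is [num_experts, k, n] and the result
+# is [tokens, n]; with trans_a=True the result is [num_experts, a.shape[1], b.shape[1]],
+# i.e. one GEMM output per expert.
+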
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm/ops.py b/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
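+
+
+# Illustrative sketch (assumes the vendored CUDA backend is available and supports
+# bfloat16; `sizes` lists how many rows of `a` each expert owns):
+#   a = torch.randn(6, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+#   b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+#   sizes = torch.tensor([4, 2])
+#   out = gmm(a, b, sizes)        # [6, 32]
+#   out.sum().backward()          # gradients flow through GroupedGemm.backward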
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm_util.py b/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored inside this package, so no external import is needed.
+    _grouped_gemm_is_available = True
+except ImportError:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/layers.py b/build/torch29-cxx11-cu126-aarch64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
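+# Worked example of the sharding arithmetic above: with world_size=8,
+# moe_num_experts=4 and ffn_hidden_size=1024, expert_sharding_degree = min(8, 4) = 4,
+# hidden_sharding_degree = 8 // 4 = 2, experts_per_rank = 4 // 4 = 1 and
+# features_per_rank = 1024 // 2 = 512.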
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
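+# For example, with moe_jitter_eps=0.01 each activation is scaled by an
+# independent random factor drawn from [0.99, 1.01).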
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
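+# Shape sketch for mlp_forward (illustrative): with x [num_experts, capacity, hs],
+# w1 [num_experts, hs, 2*ffn] and w2 [num_experts, ffn, hs], gate_up is
+# [num_experts, capacity, 2*ffn]; the even/odd slices give gate and up of
+# [num_experts, capacity, ffn], and next_states comes back as
+# [num_experts, capacity, hs].
+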
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
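+# Example: with moe_top_k=4 and shared_expert_weighted_sum=True, the shared
+# expert output is weighted by 1/5 and the routed-expert output by 4/5;
+# otherwise the two outputs are simply added.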
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (count_num_experts,) = tokens_per_expert.size()
+    assert count_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
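+# Worked example: with top_k=1, num_experts=2, 4 tokens, tokens_per_expert=[3, 1]
+# and mean expert_scores [0.75, 0.25], scale = 2 / (4 * 1) = 0.5 and the loss is
+# 0.5 * (3 * 0.75 + 1 * 0.25) = 1.25.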
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
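+# Illustrative example: top_expert=[2, 0, 2, 1] with num_experts=4 gives
+# tokens_per_expert=[1, 1, 2, 0], bins=[1, 2, 4, 4] (inclusive cumsum),
+# bin_ids=[0, 1, 2, 2] and indices=[1, 3, 0, 2] (token positions sorted by expert).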
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
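+# Worked example: tokens=1024, top_k=4, num_experts=128 on a single rank with
+# moe_capacity_factor=1.0 gives tokens_per_expert = 4 * 1024 / 128 = 32, so the
+# per-expert capacity is 32 slots.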
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/megablocks/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is: once it is added to `sys.modules`,
+    # it would also be picked up by other imports. So we construct a module
+    # name that depends on the path, using the hex-encoded hash of the path,
+    # to keep it unique.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/metadata.json b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..155112c59509d3b4d07f4d090cbf57071e3f5217
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/all_to_all_benchmark.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per element (fp16).
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/binned_gather.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/binned_scatter.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/cumsum.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
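+
+# Worked example (illustrative, assuming the standard exclusive/inclusive
+# prefix-sum semantics of the underlying kernels): for
+#   x = torch.tensor([1, 2, 3], dtype=torch.int32, device="cuda")
+# exclusive_cumsum(x, 0) yields [0, 1, 3] (sum of the preceding elements),
+# while inclusive_cumsum(x, 0) yields [1, 3, 6] (running total including the
+# current element). MegaBlocks uses the inclusive form to turn per-expert
+# token counts into bin boundaries, e.g. bins = inclusive_cumsum(tokens_per_expert, 0).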
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/gather.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/histogram.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
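+
+# Illustrative example (assuming the kernel counts occurrences of each value in
+# [0, max_val)): for x = torch.tensor([0, 1, 1, 3], dtype=torch.int32, device="cuda")
+# and max_val = 4, histogram(x, max_val) would return [1, 2, 0, 1]. This is how
+# per-token expert assignments become per-expert token counts elsewhere in this
+# package, e.g. tokens_per_expert = histogram(top_expert, num_experts).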
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1), which in turn calls
+# torch.as_strided(...). Build the strided view directly to avoid the
+# overhead of that call chain.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+            '0::Fwd::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+            '0::GradX::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+            '0::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/padded_gather.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/padded_scatter.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.binned_gather(x, indices, bins, ec, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.binned_gather(x, indices, bins, ec, 1)
+
+ def benchmark():
+            return ops.binned_scatter(x, indices, None, bins, 1)  # weights=None, top_k=1
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ def benchmark():
+            return ops.padded_scatter(x, indices, bin_ids, None, bins, padded_bins, 1)  # weights=None, top_k=1
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/repeat.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
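+
+
+# Minimal, illustrative check of the helper above: a tiling of all ones returns
+# the input tensor unchanged (no copy); anything else defers to Tensor.repeat.
+if __name__ == "__main__":
+    x = torch.tensor([[1, 2], [3, 4]])
+    assert repeat(x, torch.Size((1, 1))) is x
+    print(repeat(x, torch.Size((2, 1))).shape)  # rows tiled twice -> (4, 2)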
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/replicate.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/round_up.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
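+
+
+# Small, illustrative sanity check (round_up only uses pure torch ops, so it
+# also runs on CPU): each entry is rounded up to the next multiple of `value`.
+if __name__ == "__main__":
+    x = torch.tensor([3, 130], dtype=torch.int32)
+    print(round_up(x, 128))  # tensor([128, 256], dtype=torch.int32)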
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/scatter.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/sort.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
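+
+# Illustrative example (assuming the kernel is a stable radix sort that also
+# returns the gathering permutation): for
+#   x = torch.tensor([2, 0, 1], dtype=torch.int32, device="cuda")
+# sort(x) would return (tensor([0, 1, 2]), tensor([1, 2, 0])), i.e. the sorted
+# keys and the original position of each sorted element. Callers in this
+# package use it as `bin_ids, indices = ops.sort(top_expert)`.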
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/stk_autocast.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/sum.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
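+
+
+# Example: for x of shape (1, 4), sum(x, 0) simply squeezes away the unit
+# dimension and returns a view of shape (4,) instead of launching a reduction;
+# for x of shape (2, 4) it falls back to a regular x.sum(dim=0) of shape (4,).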
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/ops/topology.py b/build/torch29-cxx11-cu126-aarch64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/autocast.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/sputnik.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/triton_kernels.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ # Store to the sparse output matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
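+
+# Illustration (added comment, not executed): `row_indices` expands CSR-style block-row
+# offsets into one row id per nonzero block. For example, offsets = [0, 2, 3, 6]
+# (three block rows with 2, 1 and 3 nonzero blocks) fills out = [0, 0, 1, 2, 2, 2].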
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/matrix.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
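+# Illustration (added comment, not executed): for a 2 x 2 block grid with nonzero blocks
+# at (0, 1) and (1, 0) we have row_indices = [0, 1], column_indices = [1, 0] and
+# offsets = [0, 1, 2]. Sorting by column gives gather_indices = [1, 0], so
+# column_indices_t = [1, 0], block_offsets_t = [1, 0] (the position of each transposed
+# block in the original data) and offsets_t = [0, 1, 2].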
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
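+
+# Minimal construction sketch (illustrative, not executed here). Metadata dtypes follow
+# _validate_matrix: float16 data, int16 block indices, int32 offsets. A 256 x 256 matrix
+# with 128 x 128 blocking and two diagonal nonzero blocks:
+#
+#   data = torch.zeros(2, 128, 128, dtype=torch.float16)
+#   row_indices = torch.tensor([0, 1], dtype=torch.int16)
+#   column_indices = torch.tensor([0, 1], dtype=torch.int16)
+#   offsets = torch.tensor([0, 1, 2], dtype=torch.int32)
+#   m = Matrix((256, 256), data, row_indices, column_indices, offsets)
+#   m.validate()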
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+ to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
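+
+# Usage sketch (illustrative; `a` is an existing stk.Matrix): multiplying by an all-ones
+# matrix that shares a's topology leaves the entries unchanged.
+#
+#   ones = Matrix(a.size(), torch.ones_like(a.data), a.row_indices,
+#                 a.column_indices, a.offsets)
+#   assert torch.equal(mul(a, ones).data, a.data)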
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/linear_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
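+
+# Naming convention: the three letters are (output, lhs, rhs), so dsd produces a Dense
+# output from a Sparse lhs and a Dense rhs, dds a Dense output from Dense x Sparse, and
+# sdd a Sparse output (with the topology of `topo`) from two Dense operands. Usage
+# sketch (illustrative; `a_sparse` is an (M, K) stk.Matrix, `b` a (K, N) CUDA tensor
+# with block-aligned dimensions):
+#
+#   y = dsd(a_sparse, b)            # torch.Tensor of shape (M, N)
+#   z = sdd(y, b.t(), a_sparse)     # stk.Matrix with a_sparse's topology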
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
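+
+# Round-trip sketch (illustrative, mirrors matrix_ops_test.py): to_sparse followed by
+# to_dense reproduces the input exactly when its zeros are block-aligned.
+#
+#   mask = stk.random.dense_mask(256, 256, 0.5, blocking=128)
+#   x = (torch.randn(256, 256) * mask).type(torch.float16)
+#   assert torch.equal(to_dense(to_sparse(x, blocking=128)), x)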
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/random/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/random/random_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
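+
+# Example (illustrative): dense_mask(256, 256, sparsity=0.5, blocking=128) returns a
+# {0, 1} float32 tensor in which round(4 * 0.5) = 2 of the four 128 x 128 blocks are
+# kept, i.e. 2 * 128 * 128 nonzero entries.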
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py b/build/torch29-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from stk import random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/xpu_fused_moe.py b/build/torch29-cxx11-cu126-aarch64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM wrapper: builds the exclusive expert offsets from per-expert token counts.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
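+# Example of the exclusive prefix sum used above (illustrative): expert_token_count =
+# [3, 1, 2] yields expert_offset = [0, 3, 4, 6], i.e. expert i owns rows
+# [expert_offset[i], expert_offset[i + 1]) of the grouped GEMM operands.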
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
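+# Worked example (illustrative): with num_tokens = 1000 and num_experts_per_node = 8 the
+# loop returns 128, since ceilDiv(1000, 128) = 8 and 8 * 8 = 64 <= 128, while the
+# smaller candidates 32 and 64 fail the check.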
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
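+# Usage sketch (illustrative): reinterpreting eight zero bytes as a single int64 value:
+#   _bytes_to_typed_tensor(torch.zeros(8, dtype=torch.uint8), torch.int64)  # tensor([0])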
+
+def implement_zp(qweight):
+ # Convert u4 to s4 so the gemm kernel does not need to apply a zero point.
+ # Only the default zero point (8) is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
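+# Worked example (illustrative): the packed uint8 byte 0xA3 holds the u4 nibbles 10 and
+# 3; subtracting the default zero point 8 gives +2 and -5, which repack to the s4
+# nibbles 0x2 and 0xB, i.e. implement_zp maps 0xA3 -> 0x2B.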
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+ # TODO: all of this will be integrated into the C++ func. Temporarily exposed here until gemm fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
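+
+    Example (illustrative):
+        exclusive cumsum of [1, 2, 3] along dim 0 is [0, 1, 3]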
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
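+
+    Example (illustrative):
+        inclusive cumsum of [1, 2, 3] along dim 0 is [1, 3, 6]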
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
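+
+    Example (illustrative):
+        histogram of [0, 1, 1, 3] with num_bins=4 is [1, 2, 0, 1]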
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
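+
+# Illustrative usage (assumes an integer tensor on the device the kernels target):
+#   counts = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")
+#   cumsum(counts, dim=0, exclusive=True)   # -> [0, 3, 4]
+#   cumsum(counts, dim=0, exclusive=False)  # -> [3, 4, 6]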
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
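+
+# Illustrative usage (same device assumption as above):
+#   vals, idx = argsort(torch.tensor([2, 0, 1], dtype=torch.int32, device="cuda"))
+#   # vals -> [0, 1, 2]; idx -> [1, 2, 0] (idx shares x's dtype here)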
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/activation_fn.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/all_to_all.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/arguments.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+            try:
+                import triton
+                from packaging import version
+                if version.parse(triton.__version__) >= version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
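+
+
+# Illustrative construction (hypothetical values, not recommended defaults):
+#   args = Arguments(hidden_size=1024, ffn_hidden_size=4096,
+#                    moe_num_experts=8, moe_top_k=2, mlp_impl='grouped')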
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/common.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/dmoe.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
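+        # e.g. (illustrative): with blocking=128, token counts [130, 1, 0]
+        # round up to [256, 128, 0].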
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/gelu.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/glu.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Apply the activation function and the GLU gating.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/memory_test.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/mlp.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
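+    # e.g. (illustrative): with expert_sharding_degree=4 and hidden_sharding_degree=2,
+    # ranks 0-3 hold distinct expert slices of the first row shard and ranks 4-7
+    # hold the same expert slices of the second row shard.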
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+    Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class.
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
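+            # e.g. (illustrative): with moe_top_k=3, the shared expert output is
+            # weighted by 1/4 and the routed expert output by 3/4.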
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/moe.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
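+    #
+    # A small worked example (illustrative numbers only): with
+    # moe_loss_weight=0.01, moe_num_experts=64, num_layers=24, tokens=4096
+    # and moe_top_k=2, the scale is 0.01 * 64 / (24 * 4096 * 2) ~= 3.3e-6.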
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
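+        # Illustrative example: with moe_top_k=2, tokens=4096, a single
+        # expert-parallel rank and 64 experts, tokens_per_expert is
+        # 2 * 4096 / 64 = 128; with moe_capacity_factor=1.25 the capacity
+        # is int(1.25 * 128) = 160 tokens per expert.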
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
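+    # A small worked example of what indices_and_bins produces (illustrative
+    # only): with 3 experts and top_expert=[2, 0, 2, 1], the sort yields
+    # bin_ids=[0, 1, 2, 2] and indices=[1, 3, 0, 2], the histogram gives
+    # tokens_per_expert=[1, 1, 2], and the inclusive cumsum gives
+    # bins=[1, 2, 4], i.e. the end offset of each expert's tokens in the
+    # sorted order.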
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
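+        # Illustrative example (expert parallelism only, hidden sharding
+        # degree of 1): with 2 ranks and 4 experts (2 per rank), rank 0 might
+        # compute tokens_per_expert=[3, 1, 0, 2]. The all_to_all of counts
+        # then leaves rank 0 with every rank's counts for experts 0 and 1,
+        # and rank 1 with every rank's counts for experts 2 and 3, which is
+        # what the token all_to_all below needs to size its buffers.
+        #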
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+            # If we're sharding the experts along the hidden dimension,
+            # multiple devices own parts of the same set of experts.
+            # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+        # If we're sharding the experts along the hidden dimension,
+        # multiple devices own parts of the same set of experts.
+        # Replicate the token counts so devices that share experts
+        # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/mpu.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
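+
+    # Illustrative example: with world_size=8 and moe_num_experts=4,
+    # expert_sharding_degree=4 and hsd=2, so each expert's ffn_hidden_size
+    # features are split across 2 ranks.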
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/router.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
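+
+# Illustrative example: for expert_indices with 5 entries and num_experts=2,
+# _uniform_expert_assignment returns [0, 1, 0, 1, 0] reshaped to
+# expert_indices.shape, ignoring the router's actual choices.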
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch29-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_megablocks_cuda_6e04dec.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_megablocks_cuda_6e04dec.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..49c7e72e12b54f9cafe86f5fd108efd17175d314
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_megablocks_cuda_6e04dec.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fae42809a452f57bb4ef6967a397029f4e557ad73424c1b68fb613070dcd3f0d
+size 15046832
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f202b8db3c3f3028303ab4308cf35f950e2c74
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_6e04dec::{op_name}"
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_version.py b/build/torch29-cxx11-cu126-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/backend/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/backend/kernels.py b/build/torch29-cxx11-cu126-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have
+# CUDA. This approach preserves the original code while enabling testing
+# without a GPU.
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
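+
+# Illustrative example for padded_gather (top_k=1): with
+# tokens_per_expert=[2, 1], bins=[2, 3] and rows padded up to a multiple of
+# 4, padded_bins=[4, 8], so the gathered output has 8 rows; rows 2-3 and
+# 5-7 are zero padding.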
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# x: (num_experts, expert_capacity, hidden_size), real.
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens * top_k), real.
+# indices: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/benchmark_util.py b/build/torch29-cxx11-cu126-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
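+
+
+# Minimal usage sketch (illustrative; assumes CUDA tensors `a` and `b`):
+#   mean_ms, std_ms = benchmark_function(lambda: torch.mm(a, b))
+#   log_benchmark('MatMul', {'m': a.shape[0], 'n': b.shape[1]}, mean_ms, std_ms)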
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/cpu_fused_moe.py b/build/torch29-cxx11-cu126-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
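+
+
+# Usage sketch (illustrative): for an interleaved gate_up tensor of shape
+# [tokens, 2 * inter_size], the activation is applied as
+# swigluoai_activation(gate_up[..., ::2], gate_up[..., 1::2]).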
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
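+
+
+# Shape sketch (illustrative): for x of shape [tokens, hidden],
+# router_weight of shape [num_experts, hidden] and moe_top_k=4, logits is
+# [tokens, num_experts] and expert_weights / expert_indices are both
+# [tokens, 4].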
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation processes all experts in parallel using batched operations
+ instead of sequential for loops, which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # Process experts one at a time: for each expert, find the
+    # (token_idx, topk_pos) pairs routed to it and compute those tokens in a
+    # single batched matmul.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
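+
+
+# Minimal usage sketch (illustrative shapes; non-interleaved weights):
+#   _, weights, ids = route_tokens_cpu(x, router_w, None, top_k, num_experts)
+#   y = cpu_fused_moe(x.view(-1, x.shape[-1]), w1, w2, weights, ids,
+#                     activation="silu", is_interleaved=False)
+# where w1 is [num_experts, hidden, 2 * inter] and w2 is
+# [num_experts, inter, hidden].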
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/cpu_moe_cpp.py b/build/torch29-cxx11-cu126-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
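+
+
+# Minimal usage sketch (illustrative; bf16 weights, no quantization):
+#   _, weights, ids = route_tokens_cpu(x, router_w, None, top_k, num_experts)
+#   y = fused_moe_cpp(x, w1, w2, weights, ids,
+#                     w1_bias=b1, w2_bias=b2, alpha=1.702, limit=7.0,
+#                     is_vnni=False)
+# where w1 is [E, 2N, K] and w2 is [E, K, N] as documented above.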
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
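+# Illustrative wiring sketch (names and shapes are assumptions; in practice the
+# module is populated by the transformers kernels integration via the
+# @use_kernel_forward_from_hub decorator rather than by hand):
+#
+#   mlp = CPUMegaBlocksMoeMLP()
+#   mlp.router = torch.nn.Linear(hidden_size, num_experts)
+#   mlp.experts = experts_module  # exposes gate_up_proj, down_proj, num_experts
+#   out, expert_weights = mlp(torch.randn(batch, seq_len, hidden_size))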
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py b/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py b/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
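+
+
+# Illustrative shapes (an assumption-based sketch, not part of the public API):
+# for `a` of shape [sum(batch_sizes), k] and `b` of shape [num_groups, k, n]
+# with trans_b=False, gmm returns a tensor of shape [sum(batch_sizes), n],
+# multiplying each row-group of `a` by its corresponding matrix in `b`.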
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm_util.py b/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm
+ pass
+ _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+    )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/layers.py b/build/torch29-cxx11-cu126-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Compute the expert sharding degree from the world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
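+
+
+# Worked example (hypothetical numbers): with world_size=8, moe_num_experts=4
+# and ffn_hidden_size=1024, expert_sharding_degree is min(8, 4) = 4, so
+# hidden_sharding_degree is 8 // 4 = 2, experts_per_rank is 4 // 4 = 1 and
+# features_per_rank is 1024 // 2 = 512 features on each rank.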
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
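+    # Shape note (illustrative): for x of shape [sl, bs, hs] and router_weight
+    # of shape [num_experts, hs], logits come out as [sl * bs, num_experts] and
+    # expert_weights / expert_indices as [sl * bs, moe_top_k].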
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
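+    # The fused gate_up projection interleaves gate and up values along the
+    # last dim, so the even/odd slices below recover them before the clamped
+    # SwiGLU-style activation (descriptive note).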
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} tokens_per_expert "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
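+    # Worked example (hypothetical args): moe_num_experts=8, moe_loss_weight=0.01,
+    # num_layers=4, tokens=1024 and moe_top_k=2 give scale = 0.08 / 8192 ≈ 9.8e-6.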
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (histogram_num_experts,) = tokens_per_expert.size()
+    assert histogram_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
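+    # Worked example (hypothetical values): tokens=1024, top_k=2, num_experts=8
+    # and world_size=1 give tokens_per_expert = 256, so with a capacity factor
+    # of 1.0 the function returns a capacity of 256 slots per expert.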
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
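+    # Shape note (illustrative): the combined expert output x is [sl * bs, hs];
+    # tokens_per_expert is a 1-D histogram of length num_experts.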
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+    # Kick off the asynchronous exchange of per-expert token counts
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
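+    # Note (descriptive): router_scores has shape [num_experts, num_tokens]
+    # because scatter_ writes each token's expert weights into a logits-shaped
+    # buffer and the transpose flips the token/expert axes.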
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
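+        # Illustrative usage (a sketch; tensor shapes are assumptions):
+        #   mlp.set_shared_expert_weights(
+        #       up_proj_weight=torch.empty(shared_hidden_size, hidden_size),
+        #       down_proj_weight=torch.empty(hidden_size, shared_hidden_size),
+        #   )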
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..155112c59509d3b4d07f4d090cbf57071e3f5217
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/binned_gather.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/binned_scatter.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/cumsum.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("Could not import the megablocks '_ops' extension module.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/gather.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
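+
+# Example (illustrative sketch): permute tokens into expert order without
+# padding. `indices` and `bin_ids` come from sorting the router's expert
+# assignments and `bins` is the inclusive cumsum of tokens_per_expert:
+#
+#   x_expert_order = gather(x, indices, bin_ids, bins, top_k)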
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/histogram.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# The import is wrapped in a try-block so that a clearer error is raised
+# when the compiled C++ operations have not been built.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
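+
+# Example (illustrative sketch): count how many tokens were routed to each
+# expert. `top_expert` holds one expert id per (token, top_k) assignment:
+#
+#   tokens_per_expert = histogram(top_expert, num_experts)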
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+        x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd:DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradX:DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradW:DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/padded_gather.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
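+
+# Example (illustrative sketch): gather tokens into expert order with each
+# expert's slot count rounded up to the 128-wide blocking, mirroring the
+# setup used in the benchmark files:
+#
+#   bin_ids, indices = ops.sort(top_expert)
+#   tokens_per_expert = ops.histogram(top_expert, num_experts)
+#   padded_bins = ops.inclusive_cumsum(ops.round_up(tokens_per_expert, 128), 0)
+#   bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#   x = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)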
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/padded_scatter.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
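+
+# Example (illustrative sketch): the reverse of padded_gather. Combines the
+# per-expert outputs back into token order, optionally weighting each of the
+# top_k routed copies (see padded_scatter_benchmark.py for a full setup):
+#
+#   y = padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k)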
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+ return ops.binned_gather(x, indices, bins, ec)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.binned_gather(x, indices, bins, ec)
+
+ def benchmark():
+            return ops.binned_scatter(x, indices, None, bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ def benchmark():
+            return ops.padded_scatter(x, indices, bin_ids, None, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/repeat.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
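+
+# Example (illustrative sketch): an all-ones tiling is a no-op, anything else
+# defers to torch.Tensor.repeat:
+#
+#   repeat(torch.ones(2, 3), torch.Size((1, 1)))  # returned as-is
+#   repeat(torch.ones(2, 3), torch.Size((2, 1)))  # shape (4, 3)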
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/replicate.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# The import is wrapped in a try-block so that a clearer error is raised
+# when the compiled C++ operations have not been built.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
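+
+# Example (illustrative sketch): broadcast one value per bin across all of the
+# output positions covered by that bin, e.g. expanding per-expert quantities
+# to per-token values. `bins` is the inclusive cumsum of the bin sizes and
+# `num_outputs` is the total number of output positions:
+#
+#   per_token = replicate(per_expert, bins, num_outputs)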
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/round_up.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
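+
+# Example (illustrative sketch): round each count up to a multiple of `value`,
+# e.g. padding per-expert token counts to the 128-wide block size used by the
+# block-sparse kernels:
+#
+#   round_up(torch.tensor([3, 130], dtype=torch.int32), 128)  # -> [128, 256]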
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/scatter.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
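+
+# Example (illustrative sketch): the inverse of gather. Returns tokens to
+# their original order and, when `weights` is given, scales each of the top_k
+# routed copies before they are combined per token:
+#
+#   y = scatter(x, indices, bin_ids, weights, bins, top_k)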
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/sort.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# The import is wrapped in a try-block so that a clearer error is raised
+# when the compiled C++ operations have not been built.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
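+
+# Example (illustrative sketch): radix-sort the router's expert assignments,
+# returning the sorted ids and the permutation that produced them. `end_bit`
+# only needs to cover the largest expert id:
+#
+#   bin_ids, indices = sort(top_expert)     # sort on the full dtype width
+#   bin_ids, indices = sort(top_expert, 7)  # enough for expert ids < 128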
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/stk_autocast.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/sum.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
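+
+# Example (illustrative sketch): behaves like torch.sum along `dim`, but skips
+# the reduction (a squeeze suffices) when that dimension has length one, e.g.
+# for top_k == 1 routing weights:
+#
+#   sum(torch.randn(4, 1, 8), dim=1)  # shape (4, 8), no reduction needed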
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/ops/topology.py b/build/torch29-cxx11-cu126-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# The import is wrapped in a try-block so that a clearer error is raised
+# when the compiled C++ operations have not been built.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
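+
+# Example (illustrative sketch): build the column indices of the block-sparse
+# topology for the expert computation. `padded_bins` is the inclusive cumsum
+# of the padded tokens-per-expert, 128 is the block size, and the output holds
+# one int16 column index per nonzero block (see matmul_benchmark.py):
+#
+#   column_indices = topology(padded_bins, 128, block_rows, blocks_per_row)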
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/autocast.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/matrix.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
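+
+# Worked example (illustrative only): for a 2x2 grid of blocks with nonzeros at
+# block coordinates (0, 0), (0, 1) and (1, 1), the BCSR metadata is
+#   row_indices    = [0, 0, 1]
+#   column_indices = [0, 1, 1]
+#   offsets        = [0, 2, 3]
+# Sorting by column gives gather_indices = [0, 1, 2], so _transpose returns
+#   column_indices_t = [0, 0, 1]  (row indices reordered by column)
+#   offsets_t        = [0, 1, 3]  (exclusive cumsum of the per-column histogram [1, 2])
+#   block_offsets_t  = [0, 1, 2]  (location of each transposed block in `data`)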
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
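+
+# Usage sketch (illustrative; assumes a CUDA device and that `stk` is importable):
+#   a = stk.random.randn((4096, 4096), sparsity=0.8, blocking=128).cuda()
+#   b = stk.Matrix(a.size(), torch.ones_like(a.data), a.row_indices,
+#                  a.column_indices, a.offsets).cuda()
+#   c = mul(a, b)  # entries equal a.data * b.data, reusing a's topology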
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
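+
+# Usage sketch (illustrative; the tensors below are assumptions, not part of this file):
+#   x    = torch.randn(512, 1024, device="cuda", dtype=torch.float16)
+#   w    = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
+#   topo = stk.random.mask(512, 1024, sparsity=0.9, blocking=128).cuda()
+#   y = sdd(x, w, topo)    # sparse [512, 1024] output with topo's sparsity pattern
+#   z = dsd(y, w.t())      # dense [512, 1024] output from sparse x dense
+#   u = dds(x, y.t())      # dense [512, 512] output from dense x sparse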
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
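+
+# Worked example (illustrative): with blocking=2, a single block index (1, 2)
+# expands to the four element coordinates it covers:
+#   rows 2..3 x cols 4..5 -> [[2, 4], [2, 5], [3, 4], [3, 5]]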
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
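+
+# Round-trip sketch (illustrative): to_sparse/to_dense invert each other on
+# matrices whose zero blocks are exactly zero:
+#   x  = (torch.randn(256, 256) * stk.random.dense_mask(256, 256, 0.75, 16)).half()
+#   sx = to_sparse(x, blocking=16)
+#   assert torch.equal(to_dense(sx), x)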
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/random/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/random/random_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py b/build/torch29-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from stk import random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py b/build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM path (unquantized weights, no scales).
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
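+
+# Worked example (illustrative): with num_tokens=1000 and num_experts_per_node=8,
+# the loop tries 32 (32 blocks * 8 = 256 > 32), then 64 (16 * 8 = 128 > 64), then
+# 128 (8 * 8 = 64 <= 128), so it returns 128.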
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
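+
+# Example (illustrative): reinterpreting 8 workspace bytes as two int32 values.
+#   buf = torch.tensor([1, 0, 0, 0, 2, 0, 0, 0], dtype=torch.uint8)
+#   _bytes_to_typed_tensor(buf, torch.int32)
+#   # On a little-endian host this yields tensor([1, 2], dtype=torch.int32).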
+
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: fold all of this into the C++ function; temporarily exposed here until the GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
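+
+# Usage sketch (illustrative only; assumes integer inputs supported by the
+# underlying kernels):
+#   x = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")
+#   cumsum(x)                  # inclusive: tensor([3, 4, 6])
+#   cumsum(x, exclusive=True)  # exclusive: tensor([0, 3, 4])
+#   vals, idx = argsort(x)     # vals sorted ascending, idx the sorting permutation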
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/activation_fn.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/all_to_all.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
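+
+
+# Illustrative usage (a sketch, not executed here): assuming `dist` has been
+# initialized and `group` is an expert-parallel process group, the split sizes
+# are plain per-rank token counts:
+#
+#   out, handle = all_to_all(x, output_split_sizes, input_split_sizes, group,
+#                            async_op=True)
+#   handle.wait()  # overlap other work before waiting on the exchange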
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/arguments.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in the shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+ shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ from packaging import version
+ if version.parse(triton.__version__) >= version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
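+
+
+# Illustrative construction (a sketch; the sizes below are made up for the
+# example). `mlp_impl='grouped'` avoids the sparse/triton>=3.2.0 restriction
+# checked in __post_init__, but requires the grouped GEMM backend:
+#
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_impl='grouped',
+#       fp16=False,
+#       bf16=True,
+#   )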
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/common.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/dmlp_registry.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+ (i.e. only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
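+
+
+# Illustrative lookup (a sketch): with Arguments(mlp_type='glu',
+# mlp_impl='grouped'), get(args) resolves to glu.GroupedGLU(args); an unknown
+# mlp_type or mlp_impl raises a ValueError as above.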
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/dmoe.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+ # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
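+ # Worked example for the offsets built in `topology` above (made-up sizes):
+ # with padded_tokens=256 and ffn_hidden_size=256 at blocking=128, we get
+ # block_rows=2 and blocks_per_row=2, so offsets = arange(0, 5, 2) = [0, 2, 4].
+ # Every row of the sparse matrix owns exactly blocks_per_row nonzero blocks;
+ # only the column indices depend on the token-to-expert mapping.
+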
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+ # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/gelu.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
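+# The constants above come from the tanh approximation of GELU,
+# gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))):
+# 0.79788456 ~= sqrt(2/pi), and 0.79788456 + 0.1070322243 * x**2 is the
+# derivative of the tanh argument, since 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi).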
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/glu.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # Activation function and GLU gating.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GLU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/memory_test.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/mlp.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
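+# Worked sharding example (a sketch with made-up sizes, assuming
+# expert_sharding_degree=4 and hidden_sharding_degree=2, i.e. world_size=8):
+# with 8 experts and ffn_hidden_size=4096, rank 5 gets
+# expert_rank = 5 % 4 = 1 (experts [2, 4)) and row_rank = 5 // 4 = 1
+# (rows [2048, 4096)), so each rank holds a disjoint slice of the master weights.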
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+ # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+ # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+ # weighted by number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
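+ # Worked example of the weighted sum above (a sketch): with moe_top_k=4,
+ # t_experts=5, so the shared expert contributes 1/5 of the output and the
+ # routed experts contribute 4/5, keeping the overall scale comparable to an
+ # unweighted mixture of top_k + 1 experts.
+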
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/moe.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
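+ # Worked example (a sketch with made-up sizes): with top_k=2, 4096 local
+ # tokens, an expert-parallel world size of 8 and 64 experts,
+ # tokens_per_expert = 2 * 4096 * 8 / 64 = 1024, so with
+ # moe_capacity_factor=1 the capacity is 1024 tokens per expert.
+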
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+ # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension,
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/mpu.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
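+
+
+# Hedged worked example (illustrative only, assuming ffn_hidden_size is even):
+# with an expert-parallel world size of 8 and moe_num_experts=4,
+# expert_sharding_degree() == 4 and hidden_sharding_degree() == 2, so
+# experts_per_rank() == 1 while each expert's ffn_hidden_size features are
+# split across 2 ranks (features_per_rank() == ffn_hidden_size // 2).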
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/router.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence, we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_layers/sharedexpert_registry.py b/build/torch29-cxx11-cu128-aarch64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_megablocks_cuda_6e04dec.abi3.so b/build/torch29-cxx11-cu128-aarch64-linux/_megablocks_cuda_6e04dec.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..5011e7363f8536c27906751647ac7eee905efc70
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_megablocks_cuda_6e04dec.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81684a3eed6a7fb374cdbba3cf65f1cd46f5392ddc6d4992d37186c3b15f5734
+size 21085456
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f202b8db3c3f3028303ab4308cf35f950e2c74
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_6e04dec::{op_name}"
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_version.py b/build/torch29-cxx11-cu128-aarch64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/backend/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/backend/kernels.py b/build/torch29-cxx11-cu128-aarch64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA;
+# this preserves the original code while enabling testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
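+# Hedged shape note: for x of shape [tokens, hidden] and top_k=2, gather()
+# above returns a [tokens * 2, hidden] tensor whose rows are ordered by expert
+# (via bin_ids); scatter() inverts that permutation, scales each row by its
+# expert weight when weights is given, and sums the top-k copies back into a
+# [tokens, hidden] output.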
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
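+# Hedged shape note: binned_gather() above produces a dense
+# [num_experts, expert_capacity, hidden] tensor (zero-padded for experts with
+# fewer than expert_capacity tokens; overflow tokens beyond the capacity are
+# dropped), and binned_scatter() maps it back to [tokens, hidden], weighting
+# and summing the top-k copies per token.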
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/benchmark_util.py b/build/torch29-cxx11-cu128-aarch64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
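+
+
+# Hedged usage sketch (names are illustrative):
+#
+#   mean_ms, std_ms = benchmark_function(lambda: model(x))
+#   log_benchmark('MoE forward', {'batch_size': 8}, mean_ms, std_ms)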
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/cpu_fused_moe.py b/build/torch29-cxx11-cu128-aarch64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
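+# Hedged sanity check (illustrative, not executed on import): with the default
+# alpha/limit, a saturated gate and a zero up-projection give an output near
+# the clamp limit, e.g.
+#
+#   g = torch.full((2, 4), 10.0)   # clamped to 7.0
+#   u = torch.zeros(2, 4)          # clamp range is [-7, 7]
+#   swigluoai_activation(g, u)     # ~ (0 + 1) * 7 * sigmoid(7 * 1.702) ≈ 7.0
+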
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
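+# Hedged usage sketch (shapes only; values are illustrative):
+#
+#   x = torch.randn(8, 16)   # 8 tokens, hidden=16
+#   w = torch.randn(4, 16)   # 4 experts
+#   logits, wts, idx = route_tokens_cpu(x, w, None, moe_top_k=2, moe_num_experts=4)
+#   # logits: [8, 4], wts: [8, 2] (softmax over the selected experts), idx: [8, 2]
+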
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation processes all experts in parallel using batched operations
+ instead of sequential for loops, which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Process experts one at a time: for each expert, find the
+ # (token_idx, topk_pos) pairs of tokens routed to it and run its MLP on that slice.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
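+# Hedged shape example (CPU only; values are illustrative):
+#
+#   hs = torch.randn(8, 16)                     # 8 tokens, hidden_size=16
+#   w1 = torch.randn(4, 16, 64)                 # 4 experts, 2 * inter_size = 64
+#   w2 = torch.randn(4, 32, 16)                 # inter_size = 32
+#   _, wts, ids = route_tokens_cpu(hs, torch.randn(4, 16), None, 2, 4)
+#   out = cpu_fused_moe(hs, w1, w2, wts, ids)   # -> [8, 16]
+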
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/cpu_moe_cpp.py b/build/torch29-cxx11-cu128-aarch64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
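+# Hedged note: unlike cpu_fused_moe above, which expects w1 as
+# [num_experts, hidden, 2 * inter] and w2 as [num_experts, inter, hidden],
+# this entry point follows the sglang layout (w1: [E, 2N, K], w2: [E, K, N]);
+# CPUMegaBlocksMoeMLP below transposes and packs the expert weights accordingly
+# before calling it.
+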
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm/backend.py b/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm/ops.py b/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
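+
+
+# Hedged shape sketch: for a of shape [sum(batch_sizes), K] and b of shape
+# [num_groups, K, N] (trans_b=False), gmm(a, b, batch_sizes) multiplies each
+# contiguous row-group of `a` by its own matrix in `b` and returns a
+# [sum(batch_sizes), N] tensor; batch_sizes is expected to be a 1-D integer
+# tensor (typically int64 on the CPU).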
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm_util.py b/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm  # vendored into this package, so no external import is needed
+ _grouped_gemm_is_available = True
+except ImportError:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/layers.py b/build/torch29-cxx11-cu128-aarch64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree based on world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
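+# Illustrative sharding example (hypothetical sizes): with world_size=8,
+# moe_num_experts=4 and ffn_hidden_size=3072,
+#   expert_sharding_degree(8, 4) == 4
+#   hidden_sharding_degree(8, 4, 3072) == 2
+#   experts_per_rank(4, 8) == 1
+#   features_per_rank(3072, 8, 4) == 1536
+# i.e. each rank holds one expert and half of that expert's ffn features.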
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
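+
+# Shape sketch for route_tokens (illustrative): for x of shape [sl, bs, hs]
+# the returned logits are [sl * bs, moe_num_experts], and expert_weights and
+# expert_indices are both [sl * bs, moe_top_k].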
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
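+
+# Shape sketch for mlp_forward (illustrative, assuming the batched expert
+# layout produced by binned_gather later in this file): x is
+# [num_experts, capacity, hs], w1 is [num_experts, hs, 2 * ffn] with gate/up
+# columns interleaved, w2 is [num_experts, ffn, hs], and the result is
+# [num_experts, capacity, hs].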
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
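+
+# Illustrative weighting: with shared_expert_weighted_sum=True and
+# moe_top_k=3, the shared expert contributes 1/4 of the output and the routed
+# experts contribute 3/4; otherwise the two outputs are simply added.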
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
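+
+
+# Typical usage sketch (assumed training-loop integration, not enforced here):
+# each MoE layer appends (tokens_per_expert, expert_scores) with
+# save_load_balancing_loss(...) during its forward pass, the trainer calls
+# batched_load_balancing_loss(args) once per step to obtain the auxiliary
+# loss, and clear_load_balancing_loss() resets the buffer for the next step.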
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+ tokens, score_num_experts = expert_scores.size()
+ assert score_num_experts == num_experts
+ assert len(tokens_per_expert.size()) == 1
+ (counted_num_experts,) = tokens_per_expert.size()
+ assert counted_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
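+
+
+# Sanity check (illustrative): under perfectly uniform routing every expert
+# receives tokens * top_k / num_experts tokens and has a mean score of
+# 1 / num_experts, so the dot product equals tokens * top_k / num_experts and
+# the loss evaluates to exactly 1.0.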
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
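+
+
+# Worked example (hypothetical sizes): tokens=4096, top_k=4, num_experts=128
+# on a single rank with moe_capacity_factor=1.0 gives
+# tokens_per_expert = 4 * 4096 / 128 = 128, i.e. 128 slots per expert;
+# a capacity factor of 1.25 would raise this to 160.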
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Exchange token counts across ranks without blocking
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/megablocks/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is: once it is added to `sys.modules`,
+ # it would also be picked up by other imports. Instead, derive a unique
+ # module name from the hex-encoded hash of the file path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/metadata.json b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..e3e4edf582b7ffb515d0ed32e9fc9c89f125c441
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/all_to_all_benchmark.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/binned_gather.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/binned_scatter.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/cumsum.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap the extension import in a try-block so we can raise a clearer error
+# with instructions for building the C++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/gather.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/histogram.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap the extension import in a try-block so we can raise a clearer error
+# with instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1), which calls
+# torch.as_strided(...). Circumvent this chain to avoid the overhead
+# it adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
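+        # Emulate an "NT" layout for the dense baseline: materialize w contiguously
+        # in its transposed shape, then view it back so bmm reads transposed strides.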
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+            '0::Fwd::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+            '0::GradX::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+            '0::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/padded_gather.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/padded_scatter.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
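+        # Pad each expert's token count to a multiple of 128, the block size
+        # assumed by the blocked-sparse kernels.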
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+ return ops.binned_gather(x, indices, bins, ec)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.binned_gather(x, indices, bins, ec)
+
+ def benchmark():
+ return ops.binned_scatter(x, indices, bins)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ def benchmark():
+            return ops.padded_scatter(x, indices, bin_ids, None, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/repeat.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/replicate.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/round_up.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
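+    # For example, round_up(tensor([3, 130], dtype=torch.int32), 128) -> tensor([128, 256]).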
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/scatter.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/sort.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
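+        # The kernel fills x_out with the sorted values and iota_out with the permutation that sorts x.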
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/stk_autocast.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/sum.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
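+    # A size-1 dimension needs no reduction; squeezing it gives the same result.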
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/ops/topology.py b/build/torch29-cxx11-cu128-aarch64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/autocast.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/sputnik.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
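+# A dense operand is treated as "transposed" when it is a column-major view of a
+# row-major buffer, i.e. it has stride pattern (1, num_rows).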
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
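+# For Y = A @ B, dA = dY @ B^T and dB = A^T @ dY; the transpose flags below select
+# the operand order and transposition that realize these products for each layout.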
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/triton_kernels.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
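+# Default tile sizes for the blocked-sparse matmuls; BLOCK_SIZE is the sparsity block edge length.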
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: tensor dimension of length {} must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store the result tile to the sparse output matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
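+# Expand the BCSR row offsets into one row index per nonzero block.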
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/matrix.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
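+        # Number of stored values, i.e. nnz_blocks * blocking * blocking.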
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
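+
+
+# Illustrative usage sketch (not part of the original module). It assumes a
+# CUDA device and uses `stk.ops.ones_like` to build a second Matrix that
+# shares `a`'s topology, as `mul` requires:
+#
+#   import torch, stk
+#   dense = torch.randn(256, 256, dtype=torch.float16, device="cuda")
+#   a = stk.ops.to_sparse(dense, blocking=128)
+#   b = stk.ops.ones_like(a)       # same topology as `a`
+#   c = stk.ops.mul(a, b)          # entries match torch.mul on the dense forms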
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/linear_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
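+
+
+# Illustrative usage sketch (not part of the original module). It assumes a
+# CUDA device; `sdd` computes a dense-by-dense product but only materializes
+# the blocks selected by the topology Matrix:
+#
+#   import torch, stk
+#   a = torch.randn(512, 256, dtype=torch.float16, device="cuda")
+#   b = torch.randn(256, 512, dtype=torch.float16, device="cuda")
+#   topo = stk.random.mask(512, 512, sparsity=0.75, blocking=128).to("cuda")
+#   out = stk.ops.sdd(a, b, topo)          # stk.Matrix with topo's pattern
+#   out_dense = stk.ops.to_dense(out)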
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/random/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/random/random_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py b/build/torch29-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from .. import random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/xpu_fused_moe.py b/build/torch29-cxx11-cu128-aarch64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM wrapper: expert token offsets are computed on the host.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
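+
+# Example (illustrative): reinterpreting 8 raw bytes as two int32 values.
+#   buf = torch.arange(8, dtype=torch.uint8)
+#   vals = _bytes_to_typed_tensor(buf, torch.int32)   # shape (2,)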
+
+
+def implement_zp(qweight):
+ # Convert u4 to s4 to avoid handling a zero point in the gemm kernel.
+ # Only the default zero point is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+ # TODO: this will all be integrated into the C++ func. Temporarily exposed here before gemm fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
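+
+
+# Illustrative call sketch (not part of the original module). The names, sizes
+# and the "xpu" device below are assumptions for a tiny bf16 configuration;
+# real callers pass expert weights owned by an experts module (see
+# MegaBlocksMoeMLP.forward below):
+#
+#   E, H, I, T, K = 8, 256, 512, 16, 2
+#   x = torch.randn(T, H, dtype=torch.bfloat16, device="xpu")
+#   w13 = torch.randn(E, 2 * I, H, dtype=torch.bfloat16, device="xpu")
+#   w2 = torch.randn(E, H, I, dtype=torch.bfloat16, device="xpu")
+#   weights, ids = torch.topk(torch.rand(T, E, device="xpu"), K, dim=-1)
+#   out = xpu_fused_moe(x, w13, None, None, w2, None, None,
+#                       weights.float(), ids, K, "silu", E)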
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
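+
+
+# Illustrative usage sketch (not part of the original exports). It assumes a
+# CUDA device and int32 inputs, e.g. router expert assignments:
+#
+#   x = torch.randint(0, 8, (1024,), dtype=torch.int32, device="cuda")
+#   counts = histogram(x, num_bins=8)               # tokens per expert
+#   starts = cumsum(counts, dim=0, exclusive=True)  # start offset per expert
+#   sorted_x, perm = argsort(x, end_bit=3)          # radix sort + indices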
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/activation_fn.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/all_to_all.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/arguments.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+ shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/common.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+ (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
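+
+
+# Illustrative usage sketch (not part of the original module). It assumes a
+# CUDA device (for the default Arguments device factory) and, for the
+# 'grouped' backend, that the grouped_gemm dependency is available:
+#
+#   from .arguments import Arguments
+#   args = Arguments(hidden_size=1024, ffn_hidden_size=4096,
+#                    mlp_type='glu', mlp_impl='grouped')
+#   expert_mlp = get(args)   # -> glu.GroupedGLU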
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/dmoe.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
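+ # Illustrative example: with 4096 FFN features per rank, 8 experts and blocking=128, max_column_index = 256 and the sort needs ceil(log2(256)) = 8 bits.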
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses indices with the same bit width as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+ # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+ # the matrix multiplications. Calculate the starting
+ # position of each bin.
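+ # Illustrative example: tokens_per_expert=[3, 5] with blocking=128 pads to [128, 128], giving padded_bins=[128, 256].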
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/gelu.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
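+ # Multiplies g in place by d/dx of the tanh-approximate GELU (0.79788456 ~= sqrt(2/pi), 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi)).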
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/glu.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/memory_test.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
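+ # The factor of 2 below is bytes per element for the bf16 parameters used in this test.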
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/mlp.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
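+# Unwrap a torch.distributed DTensor (PyTorch >= 2.0) into its local shard so downstream kernels operate on a plain tensor.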
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+ # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+ # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+ # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/moe.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
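+ # (This matches the Switch-Transformer-style auxiliary load-balancing loss, summed over the layers owned by this pipeline stage.)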
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
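+ # e.g. 64 experts need ceil(log2(64)) = 6 bits for the radix sort (illustrative).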
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
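+ # Illustrative example: top_k=4, tokens=2048, world_size=1 and 64 experts give 128 tokens per expert; with moe_capacity_factor=1.25 the capacity is 160.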
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+ # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/mpu.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
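+ # Illustrative example: world_size=8 with 4 experts gives esd=4 and hsd=2, so each expert's FFN is split across 2 ranks, each holding half of its ffn_hidden_size rows.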
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/router.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
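+
+
+# Usage sketch (illustrative; assumes a fully populated Arguments instance `args`
+# with hidden_size, moe_num_experts, moe_top_k, etc. set):
+#   router = LearnedRouter(args)
+#   scores, expert_weights, expert_indices = router(x)   # x: [batch, seq, hidden]
+#   scores:                          [batch * seq, moe_num_experts], softmax over experts
+#   expert_weights / expert_indices: [batch * seq, moe_top_k]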
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch29-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+    """Returns a SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_megablocks_cuda_6e04dec.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_megablocks_cuda_6e04dec.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..8bcd44dbaa99ef7a3a231720c1e2365db938586e
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_megablocks_cuda_6e04dec.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0349d7de015576f9dae76f82c321d491609d1ae84bc5f2cb8053891e167a0aca
+size 20995704
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f202b8db3c3f3028303ab4308cf35f950e2c74
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_6e04dec::{op_name}"
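+
+
+# For example, add_op_namespace_prefix("gmm") returns
+# "_megablocks_cuda_6e04dec::gmm", the fully qualified name expected by
+# torch.library utilities when referring to the registered op.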
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_version.py b/build/torch29-cxx11-cu128-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/backend/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/backend/kernels.py b/build/torch29-cxx11-cu128-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub triton autotune when testing in an environment that does not have CUDA;
+# this approach preserves the original code but enables testing without a GPU
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+    # NOTE: There is no padding, so the number of output rows equals the
+    # number of input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
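+# Round-trip sketch for gather/scatter above (illustrative, top_k=1, no weights,
+# tensors on the GPU):
+#   x: [4, H]; per-token expert assignment: [1, 0, 1, 0]
+#   indices = argsort(assignment)              -> [1, 3, 0, 2]
+#   bin_ids = sorted assignment                -> [0, 0, 1, 1]
+#   bins    = cumulative tokens per expert     -> [2, 4]
+#   y = gather(x, indices, bin_ids, None, bins, top_k=1)    # rows grouped by expert
+#   scatter(y, indices, bin_ids, None, bins, top_k=1) == x  # inverse permutation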
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per (token, top_k) entry. Array 'x' uses the padded
+    # layout, so it can have more rows than there are entries.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/benchmark_util.py b/build/torch29-cxx11-cu128-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
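+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only). benchmark_function relies on
+    # torch.cuda.Event timers, so a CUDA device is required.
+    if torch.cuda.is_available():
+        a = torch.randn(1024, 1024, device="cuda")
+        b = torch.randn(1024, 1024, device="cuda")
+        mean_ms, std_ms = benchmark_function(lambda: a @ b, iterations=20, warmup=5)
+        log_benchmark("MatMul", {"m": 1024, "n": 1024, "k": 1024}, mean_ms, std_ms)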
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/cpu_fused_moe.py b/build/torch29-cxx11-cu128-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
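+# Worked example for swigluoai_activation (illustrative, defaults alpha=1.702, limit=7.0):
+#   gate = 2.0, up = 1.0
+#   glu  = 2.0 * sigmoid(2.0 * 1.702) = 2.0 * sigmoid(3.404) ~= 1.936
+#   out  = (1.0 + 1.0) * glu ~= 3.87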
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
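+# Example for route_tokens_cpu (illustrative, CPU-only):
+#   x = torch.randn(4, 16)                 # 4 tokens, hidden size 16
+#   w = torch.randn(8, 16)                 # router weight for 8 experts
+#   logits, weights, idx = route_tokens_cpu(x, w, None, moe_top_k=2, moe_num_experts=8)
+#   logits: [4, 8]; weights and idx: [4, 2]; weights sum to 1 along the last dim.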
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over the experts, but processes all of the tokens
+    assigned to each expert with batched tensor operations, which keeps it
+    reasonably efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # For each expert, find the (token_idx, topk_pos) pairs routed to it and
+    # process them together as one batch.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
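+# Minimal sketch for cpu_fused_moe (illustrative, random weights, default silu activation):
+#   T, H, I, E, K = 4, 16, 32, 8, 2
+#   x   = torch.randn(T, H)
+#   w1  = torch.randn(E, H, 2 * I) * 0.02      # gate_up_proj
+#   w2  = torch.randn(E, I, H) * 0.02          # down_proj
+#   wts = torch.full((T, K), 1.0 / K)
+#   ids = torch.randint(0, E, (T, K))
+#   out = cpu_fused_moe(x, w1, w2, wts, ids)   # -> [T, H]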
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/cpu_moe_cpp.py b/build/torch29-cxx11-cu128-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
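+# Call sketch for fused_moe_cpp (illustrative; `x`, `w1`, `w2`, `wts`, `ids` are
+# hypothetical tensors with the shapes documented above, and the compiled CPU
+# kernel from ._ops must be available):
+#   out = fused_moe_cpp(x, w1, w2, wts, ids.to(torch.int32),
+#                       w1_bias=None, w2_bias=None,
+#                       alpha=1.702, limit=7.0)   # alpha/limit enable swigluoai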
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py b/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py b/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
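+
+
+# Usage sketch (illustrative; requires the CUDA grouped-GEMM backend):
+#   a: [sum(batch_sizes), k], b: [num_groups, k, n], batch_sizes: 1-D integer tensor
+#   c = gmm(a, b, batch_sizes)   # -> [sum(batch_sizes), n], differentiable w.r.t. a and b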
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm_util.py b/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm
+ pass
+ _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+    )
+    assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend, ops
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/layers.py b/build/torch29-cxx11-cu128-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree from the world size and number of experts
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
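+# Worked example (illustrative): with world_size=8, moe_num_experts=4 and
+# ffn_hidden_size=4096, the helpers above give expert_sharding_degree -> 4,
+# hidden_sharding_degree -> 2, experts_per_rank -> 1 and features_per_rank -> 2048.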
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
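+
+# Example (hypothetical): with moe_top_k=4 and shared_expert_weighted_sum=True,
+# the shared expert contributes 1/5 of the combined output and the routed
+# experts 4/5; with shared_expert_weighted_sum=False the two outputs are summed.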
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
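+
+# Note: each MoE layer can append its (tokens_per_expert, expert_scores) pair via
+# save_load_balancing_loss during the forward pass; batched_load_balancing_loss
+# below reduces them into a single scalar, and clear_load_balancing_loss is
+# typically called once that loss has been consumed for the step.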
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup],
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
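+
+# Example (hypothetical numbers): with tokens=4096, top_k=4, num_experts=128 and
+# moe_capacity_factor=1.0 on a single device, tokens_per_expert = 4 * 4096 / 128
+# = 128, so each expert is allocated capacity for 128 tokens.
+# expert_capacity_fn below performs the same computation with default arguments
+# and is the variant called by forward_once / parallel_forward_once.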
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (count_num_experts,) = tokens_per_expert.size()
+    assert count_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
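+
+# Note: `indices` orders tokens by their assigned expert, `bin_ids` holds the
+# expert id at each sorted position, `bins` is the inclusive cumulative count
+# (the end offset of each expert's bin), and `tokens_per_expert` is the raw
+# histogram of assignments.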
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+    # Flatten the token dimension before gathering tokens to experts
+    x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
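+
+# Note: forward_once is the single-device path: tokens are binned locally,
+# pushed through mlp_forward, and scattered back in place. parallel_forward_once
+# below adds the all-to-all exchanges needed for expert model parallelism.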
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Kick off the asynchronous exchange of per-expert token counts
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
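+
+# Note: moe_forward returns the combined expert output reshaped to the input
+# shape, the top-k expert weights, and dense router scores of shape
+# [num_experts, num_tokens] built by scattering the top-k weights back over the
+# full expert dimension.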
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
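+
+# Example (hypothetical sizes): create_shared_expert_weights(1152, 3072, device,
+# dtype, torch.nn.init.kaiming_uniform_) returns (up_proj_weight,
+# down_proj_weight, None, None), which can then be handed to
+# MegaBlocksMoeMLPWithSharedExpert.set_shared_expert_weights defined below.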
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/megablocks/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..e3e4edf582b7ffb515d0ed32e9fc9c89f125c441
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/binned_gather.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/binned_scatter.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/cumsum.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/gather.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/histogram.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd:DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradX:DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradW:DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/padded_gather.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/padded_scatter.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+ return ops.binned_gather(x, indices, bins, ec)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.binned_gather(x, indices, bins, ec)
+
+ def benchmark():
+ return ops.binned_scatter(x, indices, bins)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)  # top_k == 1 here
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+        # Sample weights for the scatter reduce (top_k == 1 here).
+        weights = torch.rand((sl,)).cuda().half()
+
+        def benchmark():
+            return ops.padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/repeat.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/replicate.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
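+        # The gradient of replicating a value across its bin is the sum of the
+        # incoming gradients over that bin; replicate_backward computes that
+        # per-bin reduction.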
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/round_up.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
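+    # For example, round_up(torch.tensor([3, 130], dtype=torch.int32), 128)
+    # returns tensor([128, 256], dtype=torch.int32).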
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/scatter.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/sort.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
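+        # end_bit bounds how many key bits the sort needs to consider; callers
+        # that know their values fit in fewer bits can pass a smaller value,
+        # which typically makes the sort cheaper (see sort_benchmark.py).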
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+        mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit))
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/stk_autocast.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/sum.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/ops/topology.py b/build/torch29-cxx11-cu128-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/autocast.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
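+    # A contiguous 2D tensor viewed through .t() has strides (1, rows); treat
+    # that layout as "transposed" so the underlying contiguous buffer can be
+    # handed to the kernels directly.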
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
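+    # Each program instance computes one nonzero output block; its block-row
+    # and block-column come from the sparse topology's index arrays.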
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
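+    # One program per block row: expand the CSR-style offsets into an explicit
+    # row id for every nonzero block in that row.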
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/matrix.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
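+    # block_offsets_t[i] is the index into `data` of the i-th nonzero block
+    # when blocks are visited in transposed (column-major) order.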
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
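+
+    `data` holds the nonzero blocks as a [nnz_blocks, blocking, blocking]
+    tensor, `row_indices`/`column_indices` give each block's block-row and
+    block-column coordinates, and `offsets` marks where each block row starts
+    (one entry per block row plus one).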
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
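+
+    Example (illustrative; a and b must share one topology):
+        c = mul(a, b)  # c reuses a's indices; c.data == a.data * b.data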
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
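+    # Low-precision (fp16/bf16) matmuls accumulate rounding error, so tolerate
+    # up to `pct` percent of entries falling outside the relative tolerance.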
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
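+
+# Illustration: with blocking=2, the block coordinate (1, 0) expands to the four element
+# coordinates (2, 0), (2, 1), (3, 0) and (3, 1), i.e. every element of that 2x2 block.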
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
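+
+# Illustration: for an 8x8 matrix with blocking=4, _mask returns a 2x2 boolean tensor that
+# is True for each 4x4 block containing at least one nonzero entry.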
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
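+
+# Illustration: for an 8x8 matrix with blocking=4 whose only nonzero block is the top-left
+# one, to_sparse returns data of shape [1, 4, 4], row_indices=[0], column_indices=[0] and
+# offsets=[0, 1, 1].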
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/random/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/random/random_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
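+
+# Illustration: dense_mask(4, 4, 0.5, blocking=2) keeps round(4 * (1 - 0.5)) = 2 of the four
+# 2x2 blocks as ones and zeroes the rest, returning a float32 0/1 mask of shape [4, 4].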
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py b/build/torch29-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from stk import random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/xpu_fused_moe.py b/build/torch29-cxx11-cu128-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM wrapper: builds the exclusive prefix sum of per-expert token counts
+# on the host before dispatching to the grouped GEMM op.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
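+ # e.g. exclusive_prefix_sum([3, 1, 4]) -> [0, 3, 4, 8]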
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
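+
+# Illustration: with num_tokens=1000 and num_experts_per_node=8, candidates 32 and 64 fail
+# (ceil(1000/32)*8 = 256 > 32 and ceil(1000/64)*8 = 128 > 64), while 128 satisfies
+# ceil(1000/128)*8 = 64 <= 128, so 128 is returned.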
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
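+
+# Illustration: a 16-byte uint8 buffer reinterpreted with dtype=torch.int32 yields a tensor
+# of 4 int32 values; callers are expected to pass exactly-sized byte slices.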
+
+
+def implement_zp(qweight):
+ # Convert u4 to s4 so the GEMM kernel does not have to apply a zero point.
+ # Only the default zero point is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
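+
+# Worked example: the packed u4 byte 0x9A holds the nibbles 9 and 10, which become the s4
+# values 1 and 2 after subtracting the default zero point of 8 and are repacked as 0x12.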
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+ # TODO: this will all be integrated into the C++ op. Temporarily exposed here until GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
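+ # Each named buffer is padded to a 256-byte boundary and packed back to back into a
+ # single uint8 workspace; ws_map records (padded_size, byte_offset) for each buffer.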
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
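+ # NOTE: the softmax below normalizes only the selected top-k logits for each token,
+ # not the full expert distribution.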
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract the device_mesh captured in the experts module's (otherwise unused) forward pre-hook closure (for expert-parallel support)."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
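+# Usage sketch (illustrative; assumes a CUDA tensor, since these ops dispatch to the CUDA kernels):
+#   x = torch.tensor([1, 2, 3], device="cuda")
+#   cumsum(x)                  # -> tensor([1, 3, 6])
+#   cumsum(x, exclusive=True)  # -> tensor([0, 1, 3])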
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/activation_fn.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/all_to_all.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
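+
+# Usage sketch (illustrative sizes): with two ranks where this rank sends input_split_sizes=[2, 1]
+# rows and expects output_split_sizes=[3, 2] rows back, `out` has 5 rows. `handle` is a Work
+# object when async_op=True (wait() on it before reading `out`) and None otherwise.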
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/arguments.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in the shared expert (allows using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert if it should differ from ffn_hidden_size
+ shared_expert_weighted_sum: bool = False # enable using a weighted sum for the shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ # Compare parsed (major, minor) to avoid lexicographic string comparison (e.g. '3.10.0' < '3.2.0').
+ if tuple(int(v) for v in triton.__version__.split('.')[:2]) >= (3, 2):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/common.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/dmlp_registry.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+ (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/dmoe.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
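+
+# e.g. promote_scalar(torch.tensor(5)) -> tensor([5]); tensors with one or more dims pass through unchanged.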
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+ # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+ # Round the token counts up to the block size used in
+ # the matrix multiplications. Calculate the starting
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
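+ # Illustration (assumed values): with blocking=128 and tokens_per_expert=[130, 0, 5, 1],
+ # the padded counts are [256, 0, 128, 128], so padded_bins=[256, 256, 384, 512] and
+ # bins=[130, 130, 135, 136].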
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/gelu.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/glu.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
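+        # gate = x @ w1.t(), value = x @ v1.t(); output = (activation_fn(gate) * value) @ w2.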
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Gated activation: activation_fn(x @ w1.t()) * (x @ v1.t()).
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/memory_test.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/mlp.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
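+    # Identity in the forward pass; multiplies the incoming gradient by `scale` in the backward pass.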
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
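+    # Unwrap a DTensor to its local shard; plain tensors pass through unchanged.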
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+        # Activation function (GeLU by default).
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/moe.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
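+        # Capacity per expert: moe_capacity_factor * (top_k * tokens * world_size / num_experts).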
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
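+        # Loss = (num_experts / (tokens * top_k)) * dot(tokens_per_expert, mean router scores).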
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
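+        # The cumsum can return a 0-d tensor; promote it to a 1-element vector.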
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+            # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/mpu.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
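+    # Ranks left over after expert sharding split each expert's ffn_hidden_size rows.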
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/router.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
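+    # Per-router z-loss: mean over tokens of logsumexp(logits)^2, scaled by moe_zloss_weight.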
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
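+        # Multiplicative jitter: scale inputs by uniform noise in [1 - eps, 1 + eps).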
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
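+        # For top-1 routing, max is cheaper than a full top-k.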
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_layers/sharedexpert_registry.py b/build/torch29-cxx11-cu129-aarch64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+    """Returns a SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch29-cxx11-cu129-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..dab216ec9f834709a36442fd6d8727e6129e1a74
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d457732aa8fa3b1c8d08d6e9d48d08e6b8fc211967df7e45a82e1d88e58c9728
+size 16035488
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_version.py b/build/torch29-cxx11-cu129-aarch64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/backend/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/backend/kernels.py b/build/torch29-cxx11-cu129-aarch64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This preserves the original code while enabling testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
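+    # Each program copies (and optionally scales by 'weights') one row between the
+    # token-ordered layout 'a' and the expert-ordered, possibly padded layout 'b'.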
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
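+
+# Illustrative usage sketch (not part of the original module; the tensors are
+# hypothetical and assume routing metadata built as in layers.indices_and_bins,
+# with padded_bins rounded up to the block-sparse tile size):
+#
+#   x = torch.randn(tokens, hidden_size, device="cuda", dtype=torch.float16)
+#   out = padded_gather(x, indices, bin_ids, None, bins, padded_bins, top_k)
+#   # out has padded_bins[-1] rows; the rows for expert e occupy the slice
+#   # [padded_bins[e - 1], padded_bins[e]) and the padded tail rows stay zero.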
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
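+
+# Illustrative round trip (hypothetical sketch, not part of the original
+# module): for dropless MoE the tokens are first gathered into expert order,
+# processed, and then scattered back with the router weights applied and the
+# top-k copies summed:
+#
+#   h = gather(x, indices, bin_ids, None, bins, top_k)      # (tokens * top_k, hidden)
+#   h = expert_mlp(h)                                        # hypothetical per-expert MLP
+#   y = scatter(h, indices, bin_ids, weights, bins, top_k)   # (tokens, hidden)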
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per entry of 'wgrad' (tokens * top_k). Array 'x' has a
+    # greater or equal number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+    # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
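+
+# Illustrative capacity-based sketch (hypothetical, not part of the original
+# module): tokens are binned per expert up to a fixed expert_capacity,
+# processed as one dense (num_experts, expert_capacity, hidden) batch, then
+# scattered back to token order:
+#
+#   h = binned_gather(x, indices, None, bins, expert_capacity, top_k)
+#   h = expert_mlp(h)                                      # hypothetical batched expert MLP
+#   y = binned_scatter(h, indices, weights, bins, top_k)   # (tokens, hidden)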
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+    # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/benchmark_util.py b/build/torch29-cxx11-cu129-aarch64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+    print(f'mean time = {time:.3f}ms, std time = {std:.3f}ms')
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
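+
+# Illustrative usage (hypothetical workload, not part of the original module):
+#
+#   a = torch.randn(4096, 4096, device="cuda")
+#   b = torch.randn(4096, 4096, device="cuda")
+#   mean_ms, std_ms = benchmark_function(lambda: a @ b)
+#   log_benchmark("matmul", {"m": 4096, "n": 4096, "k": 4096}, mean_ms, std_ms)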
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/cpu_fused_moe.py b/build/torch29-cxx11-cu129-aarch64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
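+
+# Quick numeric sanity check (illustrative only): for gate = up = 1.0,
+# silu_and_mul_activation returns sigmoid(1) * 1 ~= 0.731, while
+# swigluoai_activation returns (1 + 1) * 1 * sigmoid(1.702) ~= 1.69 because of
+# the (up + 1) term and the alpha-scaled sigmoid.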
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
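+
+# Illustrative usage (hypothetical shapes and router, not part of the original
+# module):
+#
+#   x = torch.randn(2, 8, 1152)                   # [batch, seq, hidden]
+#   router = torch.nn.Linear(1152, 128)           # hypothetical router
+#   logits, weights, ids = route_tokens_cpu(x, router.weight, router.bias, 4, 128)
+#   # logits: [16, 128]; weights and ids: [16, 4]; weights are softmaxed over
+#   # the selected top-k experts only.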
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over the experts and, for each one, gathers the
+    tokens routed to it and processes them with batched tensor operations.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # Process each expert in turn: find the (token, top-k slot) pairs routed to
+    # it, run its MLP on those tokens, and accumulate the weighted outputs.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
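+
+# Illustrative end-to-end sketch (hypothetical shapes and router, not part of
+# the original module; assumes the GptOss-style interleaved gate/up layout):
+#
+#   hidden = torch.randn(16, 64)                    # [tokens, hidden]
+#   w1 = torch.randn(8, 64, 2 * 128) * 0.02         # [experts, hidden, 2*inter]
+#   w2 = torch.randn(8, 128, 64) * 0.02             # [experts, inter, hidden]
+#   router = torch.nn.Linear(64, 8)                 # hypothetical router
+#   _, weights, ids = route_tokens_cpu(hidden, router.weight, router.bias, 2, 8)
+#   out = cpu_fused_moe(hidden, w1, w2, weights, ids, activation="swigluoai")
+#   # out: [16, 64]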
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/cpu_moe_cpp.py b/build/torch29-cxx11-cu129-aarch64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+    # MXFP4/FP8 kernels only support bf16, and the C++ kernels do not support
+    # float32 inputs; convert to bf16 if needed.
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
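+
+# Illustrative usage sketch (hypothetical tensors, not part of the original
+# module). Note the sglang-style layout differs from cpu_fused_moe: w1 is
+# [E, 2N, K] and w2 is [E, K, N], and is_vnni=True is only valid after the
+# weights have been packed with ops.convert_weight_packed:
+#
+#   out = fused_moe_cpp(hidden, w1, w2, topk_weights, topk_ids.to(torch.int32),
+#                       w1_bias=b1, w2_bias=b2, alpha=1.702, limit=7.0)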
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm/backend.py b/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm/ops.py b/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
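+
+# Illustrative usage (hypothetical shapes, not part of the original module):
+# with trans_b=False, 'a' is (sum(batch_sizes), k), 'b' is (num_groups, k, n)
+# and the result is (sum(batch_sizes), n); each group of rows is multiplied by
+# its own expert matrix, and the op is differentiable w.r.t. both inputs:
+#
+#   a = torch.randn(6, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+#   b = torch.randn(3, 16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+#   batch_sizes = torch.tensor([1, 2, 3])
+#   out = gmm(a, b, batch_sizes)   # (6, 32)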
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm_util.py b/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored into this package, so there is no external
+    # `import grouped_gemm` to perform here.
+    _grouped_gemm_is_available = True
+except ImportError:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/layers.py b/build/torch29-cxx11-cu129-aarch64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+    def register_fake(op_name):
+        def decorator(fn):
+            return fn
+
+        return decorator
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+                # Meta implementation - (num_experts, expert_capacity, hidden_size)
+                if x.dim() >= 2:
+                    hidden_size = x.size(-1)
+                    return torch.empty(
+                        (bins.shape[0], bin_size, hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+                # Meta implementation - reduces to (tokens, hidden_size)
+                if x.dim() >= 3:
+                    tokens = indices.numel() // top_k if top_k > 0 else x.size(1)
+                    return torch.empty(
+                        (tokens, x.size(2)), dtype=x.dtype, device=x.device
+                    )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
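+
+# The patched wrappers above only change behaviour under torch.compile: when
+# torch.compiler.is_compiling() is true they return placeholder tensors with
+# the expected shapes instead of launching CUDA kernels, so the graph can be
+# traced; in eager mode they fall through to the original kernels, e.g.
+# (illustrative) bin_ids, indices = ops.sort(top_experts, sort_end_bit).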
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Calculate the expert sharding degree from the world size and expert count
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
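+
+# Worked example (illustrative): with world_size=8, moe_num_experts=4 and
+# ffn_hidden_size=1536, expert_sharding_degree = min(8, 4) = 4,
+# hidden_sharding_degree = 8 // 4 = 2, experts_per_rank = 4 // 4 = 1 and
+# features_per_rank = 1536 // 2 = 768, i.e. each expert is replicated on two
+# ranks that each hold half of its FFN features.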
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+    moe_jitter_eps: Optional[float] = None,
+    moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
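+
+# Shape sketch (illustrative): mlp_forward expects the binned layout produced
+# by ops.binned_gather, i.e. x is (num_experts, expert_capacity, hidden),
+# w1 is (num_experts, hidden, 2 * ffn) with interleaved gate/up columns and
+# w2 is (num_experts, ffn, hidden), so each torch.bmm call performs one
+# batched matmul per expert.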
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
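+
+# Worked example (illustrative): with tokens=4096, top_k=4, num_experts=128,
+# a single expert-parallel rank and moe_capacity_factor=1.0, each expert gets
+# int(1.0 * 4 * 4096 * 1 / 128) = 128 slots; a capacity factor of 1.25 would
+# raise that to 160.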
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (count_num_experts,) = tokens_per_expert.size()
+    assert count_num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
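+
+# Worked example (illustrative, assuming expert_scores is the per-token softmax
+# over all experts): under perfectly balanced routing each expert receives
+# tokens * top_k / num_experts tokens and has a mean score of 1 / num_experts,
+# so the dot product is tokens * top_k / num_experts and the loss reduces to
+# (num_experts / (tokens * top_k)) * (tokens * top_k / num_experts) = 1.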
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+    # Launch the count exchange asynchronously so it overlaps with the
+    # local permutation below
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
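+    # The token exchange is asynchronous; the expert-index bookkeeping below
+    # overlaps with communication, and we wait on the handle just before the
+    # expert computation.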
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
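+    # Received tokens were already replicated top_k times by the local gather
+    # and the routing weights are applied in the final scatter, so each token
+    # is processed exactly once here (top_k=1, no expert weights).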
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
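+    # router_scores: [num_experts, num_tokens], zero everywhere except each
+    # token's selected top-k experts.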
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
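+        # Number of bits the radix sort needs to cover every expert id
+        # (at least 1).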
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/megablocks/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is: once it is added to `sys.modules`,
+    # it would also be picked up by other imports. Instead, derive a unique
+    # module name from the hex-encoded hash of the file path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
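+# Load the top-level package __init__ and re-export its names so this build
+# variant exposes the same public API.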
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/metadata.json b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..e3e4edf582b7ffb515d0ed32e9fc9c89f125c441
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/all_to_all_benchmark.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per fp16 element.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/binned_gather.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/binned_scatter.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/cumsum.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
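+        # 1-D inputs are promoted to a single row (cumsum along dim 1) and the
+        # result is squeezed back to 1-D.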
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/gather.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/histogram.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/histogram_benchmark.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/matmul_benchmark.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/padded_gather.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/padded_scatter.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/padded_scatter_benchmark.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/permute_benchmark.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/repeat.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/replicate.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
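+        # Replicate each value of `x` across the output according to the bin
+        # boundaries in `bins`, producing `num_outputs` columns per row.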
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/round_up.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+ # do this in a custom kernel. We only expect
+ # to use this on arrays of less than 1k elements.
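+    # Example: round_up(tensor([130], dtype=int32), 128) -> tensor([256]).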
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/scatter.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/sort.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
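+# Returns the sorted keys together with the permutation that produced them
+# (an argsort). `end_bit` bounds how many low-order bits the sort considers and
+# defaults to the full bit-width of the input dtype.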
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/sort_benchmark.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
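+    # Timing uses CUDA events; elapsed_time reports milliseconds, so all returned
+    # statistics are in ms.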
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/stk_autocast.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/sum.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
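+    # A length-1 reduction dimension can simply be squeezed away; the result is
+    # identical to a full reduction over that dimension.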
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/ops/topology.py b/build/torch29-cxx11-cu129-aarch64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
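+# Allocates one int16 entry per output block (output_block_rows *
+# output_block_columns in total) and has ops.indices fill in the index metadata
+# for the block-sparse topology implied by the padded bins.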
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/autocast.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/sputnik.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
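+# Naming convention, shared with the Triton backend: the three letters spell the
+# density of (output, lhs, rhs). "dsd" computes a dense output from a sparse lhs
+# and a dense rhs, "dds" from a dense lhs and a sparse rhs, and "sdd" produces a
+# sparse output (on a fixed topology) from two dense operands.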
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/triton_kernels.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
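+# Expands CSR-style block-row offsets into an explicit row index per non-zero
+# block; one program instance handles one block row.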
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/matrix.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/eltwise_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/eltwise_ops_test.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/linear_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
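+# Thin wrappers around the sputnik backend: dsd(a, b) multiplies a sparse `a` by
+# a dense `b` into a dense result, dds(a, b) is the dense-times-sparse variant,
+# and sdd(a, b, topo) multiplies two dense operands but only materializes the
+# blocks present in the `topo` Matrix.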
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/linear_ops_test.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/matrix_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
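+    # A block counts as non-zero if any of its blocking x blocking entries is
+    # non-zero; the result is a [rows // blocking, cols // blocking] boolean mask.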
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/matrix_ops_test.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/random/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/random/random_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
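+    # Zero out a uniformly random subset of blocks so that exactly `nnz` blocks
+    # survive, then tile each surviving block up to blocking x blocking entries.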
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/stk/random/random_ops_test.py b/build/torch29-cxx11-cu129-aarch64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from .. import random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/xpu_fused_moe.py b/build/torch29-cxx11-cu129-aarch64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM path; cutlass_grouped_gemm_xe2 below is the Xe2-specific variant.
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
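+
+
+# Worked example (illustrative): with num_tokens=1024 and
+# num_experts_per_node=8 the loop rejects 32 (32 blocks, 32 * 8 = 256 > 32)
+# and 64 (16 blocks, 16 * 8 = 128 > 64), then accepts 128
+# (8 blocks, 8 * 8 = 64 <= 128):
+#
+#   compute_num_tokens_per_block(1024, 8)  # -> 128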
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
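+
+
+# Example (illustrative sketch): a 16-byte uint8 buffer reinterpreted as two
+# int64 values or four int32 values.
+#
+#   buf = torch.zeros(16, dtype=torch.uint8)
+#   _bytes_to_typed_tensor(buf, torch.int64).numel()  # -> 2
+#   _bytes_to_typed_tensor(buf, torch.int32).numel()  # -> 4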
+
+
+def implement_zp(qweight):
+    # Convert u4 to s4 so the GEMM kernel does not need to handle a zero point.
+    # Only the default zero point (8) is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
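+
+
+# Worked example (illustrative): each input byte holds two unsigned 4-bit
+# values; 8 is subtracted from each and the result is re-packed as a sign bit
+# plus the low three bits of the two's complement representation.
+#
+#   q = torch.tensor([0x00, 0xFF, 0x80], dtype=torch.uint8)
+#   [hex(v) for v in implement_zp(q).tolist()]  # -> ['0x88', '0x77', '0x8']
+#   # (0, 0) -> (-8, -8), (15, 15) -> (7, 7), (8, 0) -> (0, -8)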
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ function; temporarily exposed here until the GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
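+
+
+# Example call (illustrative sketch; requires an XPU device, and all shapes
+# are arbitrary). Weight layouts follow the docstring above: w13 is
+# [E, 2*inter, hidden] and w2 is [E, hidden, inter].
+#
+#   tokens, hidden, inter, experts, topk = 16, 1024, 2048, 8, 2
+#   x = torch.randn(tokens, hidden, dtype=torch.bfloat16, device="xpu")
+#   w13 = torch.randn(experts, 2 * inter, hidden, dtype=torch.bfloat16, device="xpu")
+#   w2 = torch.randn(experts, hidden, inter, dtype=torch.bfloat16, device="xpu")
+#   probs = torch.randn(tokens, experts, device="xpu").softmax(dim=-1)
+#   weights, ids = torch.topk(probs, topk, dim=-1)
+#   out = xpu_fused_moe(x, w13, None, None, w2, None, None,
+#                       weights.float(), ids, topk, "silu", experts)
+#   # out has the same shape as x: [tokens, hidden]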
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
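+
+
+# Example (illustrative): top-1 uses max() with keepdim, otherwise torch.topk.
+#
+#   scores = torch.tensor([[0.1, 0.7, 0.2]])
+#   compute_top_k(scores, 2)  # values [[0.7, 0.2]], indices [[1, 2]]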
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
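+
+
+# Example (illustrative sketch; shapes are arbitrary and the router here is a
+# plain nn.Linear used only for demonstration):
+#
+#   router = torch.nn.Linear(1024, 8)
+#   x = torch.randn(4, 16, 1024)
+#   logits, weights, ids = route_tokens_xpu(
+#       x, router.weight, router.bias, moe_top_k=2, moe_num_experts=8)
+#   # logits: [64, 8]; weights/ids: [64, 2], with weights softmaxed over the top-k dim.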
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
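+
+
+# Example (illustrative sketch; the exact dtype/device requirements follow the
+# underlying kernel, which is typically called with int32 CUDA tensors):
+#
+#   ids = torch.tensor([0, 2, 2, 1], dtype=torch.int32, device="cuda")
+#   histogram(ids, num_bins=4)  # -> [1, 1, 2, 0]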
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
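+
+
+# Example (illustrative sketch; dtype/device requirements follow the
+# underlying kernels, which are typically called on int CUDA tensors):
+#
+#   x = torch.ones(4, dtype=torch.int32, device="cuda")
+#   cumsum(x, dim=0)                  # inclusive: [1, 2, 3, 4]
+#   cumsum(x, dim=0, exclusive=True)  # exclusive: [0, 1, 2, 3]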
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
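+
+
+# Example (illustrative sketch; end_bit only needs to cover the largest value,
+# and dtype/device requirements follow the underlying sort kernel):
+#
+#   ids = torch.tensor([3, 1, 2, 1], dtype=torch.int32, device="cuda")
+#   values, indices = argsort(ids, end_bit=2)
+#   # values -> [1, 1, 2, 3]; indices -> original positions of the sorted values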
+
+
+# Export public API
+__all__ = [
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/activation_fn.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/all_to_all.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
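+
+
+# Example (illustrative sketch; assumes torch.distributed is initialized with
+# two ranks and that, on this rank, 2 rows go to rank 0 and 1 row to rank 1
+# while 2 and 3 rows are received from them):
+#
+#   out, handle = all_to_all(
+#       x,                           # [3, hidden] on this rank
+#       output_split_sizes=[2, 3],
+#       input_split_sizes=[2, 1],
+#       group=dist.group.WORLD,
+#   )
+#   # out: [5, hidden]; handle is None unless async_op=True.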
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/arguments.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (allows using a custom FC layer, e.g. te.Linear for FP8)
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+            try:
+                import triton
+                from packaging import version
+                if version.parse(triton.__version__) >= version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
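+
+
+# Example (illustrative sketch): a small dMoE configuration. Note that the
+# device field defaults to torch.cuda.current_device(), and mlp_impl='grouped'
+# requires the grouped GEMM backend to be available.
+#
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_impl='grouped',
+#       fp16=False,
+#       bf16=True,
+#   )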
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/common.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/dmlp_registry.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
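+
+
+# Example (illustrative sketch): selecting the grouped GLU expert MLP.
+#
+#   args = Arguments(hidden_size=1024, ffn_hidden_size=4096,
+#                    moe_num_experts=8, mlp_type='glu', mlp_impl='grouped')
+#   expert_mlp = get(args)  # -> glu.GroupedGLU built from args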
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/dmoe.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
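+
+
+# Example (illustrative): promote_scalar(torch.tensor(5)) returns tensor([5]),
+# while tensors that already have at least one dimension pass through unchanged.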
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/gelu.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/glu.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for the shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/memory_test.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ # NOTE: dist.init_process_group returns None; use the default (WORLD) group.
+ dist.init_process_group(backend='nccl')
+ group = dist.group.WORLD
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/mlp.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
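+
+
+# NOTE (editor sketch): `resolve_dtensor` keeps the compute path working on
+# plain local tensors. A hypothetical usage, assuming a tensor-parallel setup
+# that wraps parameters in DTensor:
+#
+#   w = resolve_dtensor(self.w1)  # DTensor -> local shard; plain Tensor -> unchanged
+#   y = torch.bmm(x, w)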
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
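+
+
+# NOTE (editor example, values chosen only for illustration): with
+# world_size = 4, moe_num_experts = 8 and ffn_hidden_size = 4096,
+# expert_sharding_degree = min(4, 8) = 4 and hidden_sharding_degree = 4 // 4 = 1,
+# so each rank keeps 8 // 4 = 2 consecutive experts and all 4096 rows. With
+# world_size = 16 instead, expert_sharding_degree = 8 and
+# hidden_sharding_degree = 2, so each rank keeps a single expert and 2048 of its rows.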
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+ # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first, the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
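+
+
+# NOTE (editor sketch): MLP.forward assumes the input was already permuted into
+# per-expert blocks of equal capacity, i.e.
+#   x:  [experts_per_rank, expert_capacity, hidden_size]
+#   w1: [experts_per_rank, hidden_size, ffn_per_rank]
+#   w2: [experts_per_rank, ffn_per_rank, hidden_size]
+# so the first bmm produces [experts_per_rank, expert_capacity, ffn_per_rank]
+# and the second bmm maps back to hidden_size.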
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
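+
+
+# NOTE (editor summary): MemoryOptimizedMLP trades compute for memory. Forward
+# saves only (w1, w2, the topology tensors, x, sdd_out); backward re-runs the
+# activation and reuses the activation/gradient buffers in place rather than
+# allocating separate dactivation_fn_out, dsdd_out and dx tensors. Illustrative
+# call, assuming a valid block-sparse `topo` built by the dMoE routing code:
+#
+#   out = memory_optimized_mlp(x, w1, w2, topo, torch.nn.functional.gelu)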
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+ # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first, the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # Use a weighted sum for the shared expert output, weighted by the
+ # total number of experts used:
+ # out = shared / (top_k + 1) + expert_out * top_k / (top_k + 1).
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/moe.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
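+
+
+# NOTE (editor example, illustrative numbers only): with moe_num_experts = 8,
+# moe_loss_weight = 0.1, num_layers = 12, tokens = 2048 and moe_top_k = 2 the
+# scale above is (8 * 0.1) / (12 * 2048 * 2) ~ 1.6e-5, applied to the dot
+# product of per-expert token counts and mean router scores.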
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
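+
+ # NOTE (editor example, illustrative numbers only): with tokens = 4096,
+ # top_k = 2, world_size = 8, num_experts = 64 and moe_capacity_factor = 1.25,
+ # tokens_per_expert = 2 * 4096 * 8 / 64 = 1024, so expert_capacity = 1280.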
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
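+
+ # NOTE (editor example): for top_expert = [2, 0, 1, 0] with 3 experts, sort
+ # yields bin_ids = [0, 0, 1, 2] and indices = [1, 3, 2, 0], histogram yields
+ # tokens_per_expert = [2, 1, 1], and the inclusive cumsum gives
+ # bins = [2, 3, 4], i.e. expert e owns sorted rows [bins[e - 1], bins[e])
+ # (with bins[-1] read as 0).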
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+ # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/mpu.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__()
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
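+
+
+# NOTE (editor example, illustrative numbers only): with world_size = 8 and
+# moe_num_experts = 4, expert_sharding_degree = 4 and hidden_sharding_degree = 2,
+# so experts_per_rank = 1 and features_per_rank = ffn_hidden_size // 2. Ranks
+# that share an expert each hold half of its ffn rows.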
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/router.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
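+
+
+# NOTE (editor note): per router, the unscaled z-loss above is the mean over
+# tokens i of (logsumexp_j logits[i, j]) ** 2, which penalizes large router
+# logits and keeps the softmax well conditioned.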
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch29-cxx11-cu129-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so b/build/torch29-cxx11-cu129-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..d6d595b1f0221f9d3793e58d2f549ef17693f655
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_megablocks_cuda_7a6bcf4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94304fb698702f77c92b943ee0a64f00b26aedca7afa944c3a470de2ca7a13e5
+size 16003376
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde001f8cf5f78a02794d6e9a81fd8195e65d77
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_7a6bcf4
+ops = torch.ops._megablocks_cuda_7a6bcf4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_7a6bcf4::{op_name}"
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_version.py b/build/torch29-cxx11-cu129-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/backend/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/backend/kernels.py b/build/torch29-cxx11-cu129-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment without CUDA.
+# This preserves the original code while still allowing tests to run on CPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has at least as many
+ # rows, since its bins may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
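+
+
+# NOTE (editor sketch): gather followed by scatter is the MoE dispatch/combine
+# pair used by the dMoE layers. An illustrative round trip, where `expert_mlp`
+# is a stand-in for the per-expert computation and `indices`, `bin_ids`, `bins`
+# and `expert_weights` come from the routing code:
+#
+#   h = gather(x, indices, bin_ids, None, bins, top_k)  # [tokens * top_k, hidden]
+#   h = expert_mlp(h)
+#   y = scatter(h, indices, bin_ids, expert_weights, bins, top_k)  # [tokens, hidden]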
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per routed token. The expert-ordered input 'x' has at
+ # least as many rows, since its bins may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# x: (num_experts, expert_capacity, num_columns), real.
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens * top_k), real.
+# indices: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/benchmark_util.py b/build/torch29-cxx11-cu129-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
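+
+
+# Usage sketch (illustrative; assumes a CUDA device is available): time a plain
+# matrix multiply with the helpers above.
+def _benchmark_example():
+    a = torch.randn(1024, 1024, device="cuda")
+    mean_ms, std_ms = benchmark_function(lambda: a @ a)
+    log_benchmark("MatMul", {"shape": "1024x1024"}, mean_ms, std_ms)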
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/cpu_fused_moe.py b/build/torch29-cxx11-cu129-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
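+
+
+# Sanity-check sketch (illustrative, not used by the kernels): the formula from
+# the docstring above, written out with unfused ops, should match the function.
+def _swigluoai_check(gate: torch.Tensor, up: torch.Tensor) -> bool:
+    g = gate.clamp(max=7.0)
+    ref = (up.clamp(min=-7.0, max=7.0) + 1) * (g * torch.sigmoid(g * 1.702))
+    return bool(torch.allclose(swigluoai_activation(gate, up), ref))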
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over the experts, but handles all tokens routed to
+    each expert with batched tensor operations, which keeps the work inside
+    fast ATen kernels on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # Build expert mask: which tokens go to which expert
+ # expert_mask[expert_id] contains indices of (token_idx, topk_pos) pairs
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
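+
+
+# Usage sketch (illustrative only; the sizes below are arbitrary assumptions,
+# not part of the kernel API): route tokens with `route_tokens_cpu`, then
+# dispatch them through `cpu_fused_moe` with a non-interleaved gate/up layout.
+def _cpu_fused_moe_example() -> torch.Tensor:
+    tokens, hidden, inter, experts, top_k = 8, 32, 64, 4, 2
+    x = torch.randn(tokens, hidden)
+    router_weight = torch.randn(experts, hidden)
+    w1 = torch.randn(experts, hidden, 2 * inter) * 0.02  # gate_up_proj layout
+    w2 = torch.randn(experts, inter, hidden) * 0.02      # down_proj layout
+    _, topk_weights, topk_ids = route_tokens_cpu(x, router_weight, None, top_k, experts)
+    return cpu_fused_moe(x, w1, w2, topk_weights, topk_ids,
+                         activation="silu", is_interleaved=False)  # [tokens, hidden]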
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/cpu_moe_cpp.py b/build/torch29-cxx11-cu129-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+        # Default the MXFP4 flag; it is enabled below when GPT-OSS precision configs are detected.
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm/backend.py b/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm/ops.py b/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
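+
+
+# Usage sketch (illustrative; requires the compiled backend ops and a CUDA
+# device): `a` stacks the rows of every group, `b` holds one weight matrix per
+# group, and `batch_sizes` (a CPU int64 tensor) says how many rows of `a`
+# belong to each group.
+def _gmm_example():
+    batch_sizes = torch.tensor([3, 5], dtype=torch.int64)            # two groups
+    a = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)      # [sum(batch_sizes), k]
+    b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16)  # [num_groups, k, n]
+    return gmm(a, b, batch_sizes)                                    # [8, 32]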
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm_util.py b/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # grouped_gemm is vendored inside this package, so the import normally succeeds.
+    from . import grouped_gemm  # noqa: F401
+    _grouped_gemm_is_available = True
+except ImportError:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+    )
+    assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/layers.py b/build/torch29-cxx11-cu129-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
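+
+
+# Worked example (illustrative numbers, not tied to any real configuration):
+# with 8 ranks and 4 experts, the experts are sharded 4 ways and each expert's
+# FFN hidden dimension is split across the remaining factor of 2.
+def _sharding_example() -> None:
+    world_size, num_experts, ffn_hidden = 8, 4, 1024
+    assert expert_sharding_degree(world_size, num_experts) == 4
+    assert hidden_sharding_degree(world_size, num_experts, ffn_hidden) == 2
+    assert experts_per_rank(num_experts, world_size) == 1
+    assert features_per_rank(ffn_hidden, world_size, num_experts) == 512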
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
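+
+
+# Example (illustrative): with the weighted sum enabled and moe_top_k=3, the
+# shared expert contributes 1/4 of the output and the routed experts 3/4.
+def _combine_example() -> torch.Tensor:
+    shared = torch.ones(2, 4)
+    routed = torch.zeros(2, 4)
+    return combine_expert_shared_outputs(shared, routed, True, moe_top_k=3)  # all entries 0.25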
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    assert tokens_per_expert.numel() == num_experts
+    scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
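+
+
+# Example (illustrative): with perfectly uniform routing the auxiliary
+# load-balancing loss evaluates to exactly 1.0.
+def _load_balancing_loss_example() -> torch.Tensor:
+    tokens, top_k, num_experts = 16, 2, 4
+    expert_scores = torch.full((tokens, num_experts), 1.0 / num_experts)
+    tokens_per_expert = torch.full((num_experts,), tokens * top_k / num_experts)
+    return load_balancing_loss(tokens_per_expert, expert_scores, top_k, num_experts)  # tensor(1.)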
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
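+
+
+# Example (illustrative): 1024 tokens with top-4 routing over 128 experts and a
+# capacity factor of 1.0 gives 4 * 1024 / 128 = 32 slots per expert.
+def _expert_capacity_example() -> None:
+    assert expert_capacity_fn(1024, 4, 128, None, 1.0, False) == 32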
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Kick off the exchange asynchronously so it overlaps with the local permutation below.
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
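+# Note: `init_method` (and `output_layer_init_method`) must initialize the
+# given tensor in place, so `torch.nn.init` functions can be passed directly;
+# see the illustrative `_example_shared_expert_setup` sketch further below.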
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
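+#
+# Illustrative sketch (an assumption about the integration, not code in this
+# module) of the hook shape this helper expects: a pre-forward hook defined as
+# a nested function that closes over `device_mesh`,
+#
+#   def _attach_mesh(module, device_mesh):
+#       def hook(mod, args):  # `device_mesh` is a free variable of `hook`
+#           return args
+#       module.register_forward_pre_hook(hook)
+#
+# which is why it can later be recovered from `hook.__closure__` below.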
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
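+# The helper below is an illustrative sketch only: it is not exercised
+# anywhere in this module, and the sizes (1024 hidden units, 4096 shared
+# expert units), dtype, and activation are placeholder assumptions rather
+# than library defaults.
+def _example_shared_expert_setup(
+    mlp: MegaBlocksMoeMLPWithSharedExpert,
+    hidden_size: int = 1024,
+    shared_expert_hidden_size: int = 4096,
+) -> MegaBlocksMoeMLPWithSharedExpert:
+    """Wire shared-expert weights into an existing module (illustrative)."""
+    up_w, down_w, up_b, down_b = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=shared_expert_hidden_size,
+        device=torch.device("cuda"),
+        dtype=torch.bfloat16,
+        init_method=torch.nn.init.xavier_uniform_,
+    )
+    mlp.set_shared_expert_weights(
+        up_proj_weight=up_w,
+        down_proj_weight=down_w,
+        up_proj_bias=up_b,
+        down_proj_bias=down_b,
+        weighted_sum=False,
+        activation_fn=torch.nn.functional.gelu,
+    )
+    return mlp
+
+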
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/megablocks/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is: after adding it to `sys.modules`,
+    # it would also be used for other imports. Instead, we derive a unique
+    # module name from the hex-encoded hash of the file path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
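+# Load the `__init__.py` that lives one directory above this file under a
+# path-derived module name and re-export its symbols into this module.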
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/metadata.json b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..e3e4edf582b7ffb515d0ed32e9fc9c89f125c441
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json
@@ -0,0 +1,21 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "10.1",
+ "12.0",
+ "7.0",
+ "7.2",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per fp16 element.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/binned_gather.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/binned_scatter.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/cumsum.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/gather.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/histogram.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/histogram_benchmark.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b9c6047567b87a295979498142230d1b0c9411
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class HistogramBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_HISTOGRAM_TESTS)
+# def testTorchHistogram(self, n, dtype, max_val):
+# x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/matmul_benchmark.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b9b8866ffed2eb769b77f2320c82e5491ae0e
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1) which calls
+# torch.as_strided(...). Circumvent this chain to avoid an overhead
+# this adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+# class MatmulBenchmark(parameterized.TestCase):
+#
+# def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+# blocking = 128
+# padded_tokens, _ = x.size()
+# assert padded_tokens % blocking == 0
+# assert fhs % blocking == 0
+#
+# # Offsets for the sparse matrix. All rows have the
+# # same number of nonzero blocks dictated by the
+# # dimensionality of a single expert.
+# block_rows = padded_tokens // blocking
+# blocks_per_row = fhs // blocking
+# offsets = torch.arange(
+# 0,
+# block_rows * blocks_per_row + 1,
+# blocks_per_row,
+# dtype=torch.int32,
+# device=x.device,
+# )
+#
+# # Indices for the sparse matrix. The indices for
+# # the intermediate matrix are dynamic depending
+# # on the mapping of tokens to experts.
+# column_indices = ops.topology(
+# padded_bins,
+# blocking,
+# block_rows,
+# blocks_per_row,
+# )
+# data = torch.empty(
+# column_indices.numel(),
+# blocking,
+# blocking,
+# dtype=torch.float16,
+# device=x.device,
+# )
+# shape = (padded_tokens, fhs * ne)
+# row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+# return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+#
+# def build_input_matrix(self, sl, hs, ne):
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Assign tokens to experts uniformly.
+# top_expert = torch.arange(0, sl).cuda().int() % ne
+#
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+# return out, padded_bins
+#
+# def build_weight_matrix(self, ne, hs, fhs):
+# return torch.randn((hs, ne * fhs)).cuda().half()
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(x, w, topo)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(topo, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradX::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# topo = topo.t()
+#
+# def benchmark():
+# return stk.ops.dsd(topo, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+#
+# def benchmark():
+# return stk.ops.dsd(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DSD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# w = transpose_view(w)
+#
+# def benchmark():
+# return stk.ops.sdd(out, w, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::SDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+# x, padded_bins = self.build_input_matrix(sl, hs, ne)
+# w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+# x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+# out = stk.ops.dsd(x, w)
+# x = x.t()
+#
+# def benchmark():
+# return stk.ops.dsd(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DSD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.nnz * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+#
+# w = w.transpose(1, 2).contiguous()
+# w = w.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0::Fwd:DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = w.transpose(1, 2).contiguous()
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradX:DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, hs)).cuda().half()
+# w = torch.randn((ne, hs, fhs)).cuda().half()
+# out = torch.bmm(x, w)
+# out = out.transpose(1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '0:GradW:DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * fhs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+#
+# def benchmark():
+# return torch.bmm(x, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::Fwd::DDD::NN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# w = torch.transpose(w, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(out, w)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradX::DDD::NT',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+#
+# @parameterized.parameters(*_MATMUL_TESTS)
+# def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+# assert (sl % ne) == 0
+# x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+# w = torch.randn((ne, fhs, hs)).cuda().half()
+# out = torch.bmm(x, w)
+# x = torch.transpose(x, 1, 2)
+#
+# def benchmark():
+# return torch.bmm(x, out)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'ffn_hidden_size': fhs,
+# 'num_experts': ne,
+# }
+# log_benchmark(
+# '1::GradW::DDD::TN',
+# arguments,
+# mean_t,
+# std_t,
+# x.numel() * hs * 2,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/padded_gather.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/padded_scatter.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbe4735891446b46f93170c64c23fe63632bf93
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+# class PaddedScatterTest(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+# def testPaddedScatter(self, sl, hs, ne, top_k):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# # Sample weights for the scatter reduce.
+# weights = torch.rand((sl * top_k,)).cuda().half()
+#
+# # Gather the data to prepare for backwards.
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#
+# def benchmark():
+# return ops.padded_scatter(
+# x,
+# indices,
+# bin_ids,
+# weights,
+# bins,
+# padded_bins,
+# top_k,
+# )
+#
+# time, std = benchmark_util.benchmark_function(benchmark)
+# benchmark_util.log_benchmark(
+# 'Padded Scatter',
+# {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# 'top_k': top_k,
+# },
+# time,
+# std,
+# )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/permute_benchmark.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..697abddbb3a2082ec4ddd6d94f89f7faabb34b40
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+# from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+# class PermuteBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedGather(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.binned_gather(x, indices, bins, ec)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testBinnedScatter(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(indices, ne)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.binned_gather(x, indices, bins, ec)
+#
+# def benchmark():
+# return ops.binned_scatter(x, indices, bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedGather(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#
+# def benchmark():
+# return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testPaddedScatter(self, sl, hs, ne):
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+#
+# # Randomly assign tokens to experts.
+# top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+# bin_ids, indices = ops.sort(top_expert)
+# tokens_per_expert = ops.histogram(top_expert, ne)
+# padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+# padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+# x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+#
+# def benchmark():
+# return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_PERMUTE_TESTS)
+# def testCopy(self, sl, hs, ne):
+# # NOTE: Capacity factor == 1.
+# # ec = sl // ne
+#
+# # Create the data and indices.
+# x = torch.randn((sl, hs)).cuda().half()
+# y = x.clone()
+#
+# def benchmark():
+# return y.copy_(x)
+#
+# mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+# arguments = {
+# 'sequence_length': sl,
+# 'hidden_size': hs,
+# 'num_experts': ne,
+# }
+# benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/repeat.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/replicate.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/round_up.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of fewer than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
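+
+
+# Example (illustrative values):
+#   round_up(torch.tensor([5, 128, 130], dtype=torch.int32), 128)
+#   -> tensor([128, 128, 256])
+# i.e. every entry is rounded up to the nearest multiple of `value`.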
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/scatter.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/sort.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
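+
+# Example (illustrative): for per-token expert ids `top_expert`,
+#   bin_ids, indices = sort(top_expert, sort_end_bit)
+# returns the expert ids in ascending order (`bin_ids`) together with the
+# permutation that produced them (`indices`), i.e. an argsort keyed on the
+# low `sort_end_bit` bits.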
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/sort_benchmark.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..11043c0824c36372585f1d9f48480c2a6ef32eb6
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+# class SortBenchmark(parameterized.TestCase):
+#
+# @parameterized.parameters(*_SORT_TESTS)
+# def testSort(self, n, dtype, max_val):
+# if max_val is None:
+# max_val = np.iinfo(numpy_dtype(dtype)).max
+# end_bit = int(np.ceil(np.log2(max_val)))
+# x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+# arguments = {
+# 'n': n,
+# 'dtype': dtype,
+# 'max_val': max_val,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+#
+# @parameterized.parameters(*_BASELINE_SORT_TESTS)
+# def testTorchSort(self, n):
+# x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+#
+# mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+# arguments = {
+# 'n': n,
+# }
+# log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/stk_autocast.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+        # Cast dict values (and keys) so keyword arguments are handled too.
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, (list, tuple)):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/sum.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/ops/topology.py b/build/torch29-cxx11-cu129-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/autocast.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+        # Cast dict values (and keys) so keyword arguments are handled too.
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, (list, tuple)):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/sputnik.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
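+
+# Naming note for this backend: the three-letter op names read output-lhs-rhs,
+# so `dsd` computes a dense output from a sparse lhs and a dense rhs, `dds` a
+# dense output from a dense lhs and a sparse rhs, and `sdd` a sparse output
+# (with a fixed topology) from two dense operands.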
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/triton_kernels.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: tensor dimension of length {} must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+ #Store to sparse matrix
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/matrix.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers.
+# 3. Make indentation consistent.
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+ block_rows = np.prod(shape[:-1]) / block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
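+
+
+# Minimal construction sketch (illustrative only; assumes a CUDA device and
+# float16 data so that validate() passes, and uses the stk.ops helpers defined
+# elsewhere in this package):
+#
+#   import stk
+#   dense = torch.randn(256, 256, dtype=torch.float16, device="cuda")
+#   x = stk.ops.to_sparse(dense, blocking=128)  # returns a Matrix
+#   x.validate()
+#   assert stk.ops.to_dense(x).shape == dense.shape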
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
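+
+# Usage sketch (illustrative only; assumes `a` and `b` are stk.Matrix objects
+# that share the same topology, e.g. built with stk.ops.to_sparse from dense
+# tensors masked by the same block pattern):
+#
+#   c = mul(a, b)  # Matrix with a's topology
+#   # stk.ops.to_dense(c) equals stk.ops.to_dense(a) * stk.ops.to_dense(b)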
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939372a5c68ac92b47b11015db4f75f4fd60ffa
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+# from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+# @parameterized.parameters(_ELTWISE_OP_TESTS)
+# class EltwiseOpsTest(parameterized.TestCase):
+#
+# def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+#
+# a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+# b_dense, b = _dense_and_sparse_like(a)
+#
+# out = stk.ops.mul(a, b)
+# expected_out = torch.mul(a_dense, b_dense)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size(), out.size())
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = a_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = b_dense.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size(), grad.size())
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/linear_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
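+
+
+# Usage sketch (illustrative only; the shapes and sparsity pattern are
+# hypothetical, and a CUDA device with float16 data is assumed):
+#
+#   topo = stk.random.mask(512, 512, sparsity=0.75, blocking=128).cuda()
+#   a = torch.randn(512, 1024, dtype=torch.float16, device="cuda")
+#   b = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
+#   out = sdd(a, b, topo)             # sparse Matrix with topo's topology
+#   y = dsd(out, b.t().contiguous())  # dense [512, 1024] result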
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c24d350df9c1b2346c7da885502cd696c88867
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+# from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+# @parameterized.parameters(*_LINEAR_OP_TESTS)
+# class LinearOpsTest(parameterized.TestCase):
+#
+# def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = stk.ops.to_dense(a.grad)
+# expected_grad = _mask(a_dense.grad, a.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+# expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# out.sum().backward()
+#
+# # Validate the results.
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = stk.ops.to_dense(b.grad)
+# expected_grad = _mask(b_dense.grad, b.grad)
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+# # Construct the operands.
+# a_shape = (k, m) if trans_a else (m, k)
+# a, acp = _dense_2x(*a_shape, dtype)
+# b_shape = (n, k) if trans_b else (k, n)
+# b, bcp = _dense_2x(*b_shape, dtype)
+# _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+#
+# # Execute the matmul.
+# out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+# expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+#
+# # Compute the gradients w.r.t. the inputs.
+# expected_out.sum().backward()
+# stk.ops.sum(out).backward()
+#
+# # Validate the results.
+# out = stk.ops.to_dense(out)
+# self.assertEqual(out.dim(), 2)
+# self.assertEqual(expected_out.size()[0], out.size()[0])
+# self.assertEqual(expected_out.size()[1], out.size()[1])
+# self.assertTrue(allclose(out, expected_out))
+#
+# # LHS gradient.
+# grad = a.grad
+# expected_grad = acp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+#
+# # RHS gradient.
+# grad = b.grad
+# expected_grad = bcp.grad
+# self.assertEqual(grad.dim(), 2)
+# self.assertEqual(expected_grad.size()[0], grad.size()[0])
+# self.assertEqual(expected_grad.size()[1], grad.size()[1])
+# self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
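+
+
+# Round-trip sketch (illustrative only; a CUDA float16 input is assumed so the
+# resulting Matrix also passes Matrix.validate()):
+#
+#   dense = torch.randn(256, 256, dtype=torch.float16, device="cuda")
+#   sparse = to_sparse(dense, blocking=128)
+#   assert torch.equal(to_dense(sparse), dense)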
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d172d921f6f08b0e4fb709207a458b0e1e071bd0
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+# from absl.testing import parameterized
+import stk
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class MatrixOpsTest(parameterized.TestCase):
+#
+# def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+# mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+# x = (torch.randn(rows, cols) * mask).type(torch.float16)
+#
+# # Convert the matrix to sparse format.
+# sparse_x = stk.ops.to_sparse(x, blocking)
+#
+# # Validate the matrix.
+# sparse_x.validate()
+#
+# # Validate the shape.
+# self.assertEqual(sparse_x.dim(), 2)
+# self.assertEqual(sparse_x.size()[0], rows)
+# self.assertEqual(sparse_x.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(sparse_x.nnz, nnz)
+#
+# # Convert back to dense format.
+# dense_x = stk.ops.to_dense(sparse_x)
+#
+# # Validate the shape.
+# self.assertEqual(dense_x.dim(), 2)
+# self.assertEqual(dense_x.size()[0], rows)
+# self.assertEqual(dense_x.size()[1], cols)
+#
+# # Validate the sparsity
+# self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+#
+# # Validate the output.
+# self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/random/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/random/random_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
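+
+
+# Usage sketch (illustrative only):
+#
+#   m = dense_mask(256, 512, sparsity=0.75, blocking=128)  # dense 0/1 float32 mask
+#   s = mask(256, 512, sparsity=0.75, blocking=128)         # stk.Matrix of ones (fp16)
+#   x = randn((256, 512), sparsity=0.5, blocking=128)       # random data in a random topology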
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/stk/random/random_ops_test.py b/build/torch29-cxx11-cu129-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d716b78b5ec009cbf9ac2dfdf09162a0102e62
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+# from absl.testing import parameterized
+from stk import random
+import torch
+
+
+# @parameterized.parameters(
+# (8, 16, 0.0, 1),
+# (8, 16, 0.5, 1),
+# (8, 16, .95, 1),
+# (16, 8, 0.0, 1),
+# (16, 8, 0.5, 1),
+# (16, 8, .95, 1),
+# (8, 16, 0.0, 8),
+# (8, 16, 0.5, 8),
+# (8, 16, 1.0, 8),
+# (16, 8, 0.0, 8),
+# (16, 8, 0.5, 8),
+# (16, 8, 1.0, 8),
+# (128, 256, 0.5, 16),
+# (256, 128, 0.75, 32),
+# (512, 512, .875, 128))
+# class RandomOpsTest(parameterized.TestCase):
+#
+# def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+# mask = random.dense_mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(
+# torch.count_nonzero(mask).item(),
+# nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask, 0),
+# torch.eq(mask, 1))))
+#
+# def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+# mask = random.mask(
+# rows, cols, sparsity, blocking)
+#
+# # Validate the matrix.
+# mask.validate()
+#
+# # Validate the shape.
+# self.assertEqual(mask.dim(), 2)
+# self.assertEqual(mask.size()[0], rows)
+# self.assertEqual(mask.size()[1], cols)
+#
+# # Validate the sparsity.
+# numblocks = rows // blocking * cols // blocking
+# nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+# self.assertEqual(mask.nnz, nnz)
+#
+# # Check values are zero or one.
+# self.assertTrue(
+# torch.all(torch.logical_or(
+# torch.eq(mask.data, 0),
+# torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/xpu_fused_moe.py b/build/torch29-cxx11-cu129-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# default
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
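+    # For example, with expert_token_count = [3, 0, 5] the exclusive prefix sum
+    # above yields expert_offset = [0, 3, 3, 8]: expert i owns rows
+    # [expert_offset[i], expert_offset[i + 1]) of the grouped operands.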
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
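+
+# For example, with num_tokens = 4096 and num_experts_per_node = 8, the first
+# candidate satisfying the bound is 256: ceilDiv(4096, 256) * 8 == 128 <= 256.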
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
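+
+# For example, a 32-byte uint8 workspace slice reinterpreted as torch.int64
+# yields a 4-element tensor (32 // 8); the bytes are copied, not aliased.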
+
+
+def implement_zp(qweight):
+    # Convert u4 to s4 so the GEMM kernel does not need to apply a zero point.
+    # Only the default zero point (8) is supported for now.
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
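+
+# Worked example: the packed byte 0x9F (high nibble 9, low nibble 15) maps to
+# the signed values 9 - 8 = 1 and 15 - 8 = 7; pack_compact re-encodes each as a
+# sign bit plus the low three bits, giving nibbles 0b0001 and 0b0111, i.e. 0x17.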
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ function; temporarily exposed
+    # here until the GEMM fusion lands.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
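+    # Each sub-buffer size is rounded up to a multiple of 256 bytes so every region
+    # starts at a 256-byte-aligned offset, e.g. a 1000-byte request reserves 1024 bytes.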
+
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
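+# For example, with moe_jitter_eps=0.01 every element of x is scaled by a factor
+# drawn uniformly from [0.99, 1.01).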
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
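+# Illustrative shapes: for x of shape [tokens, hidden_size] the returned logits are
+# [tokens, moe_num_experts], while expert_weights and expert_indices are both
+# [tokens, moe_top_k].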
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
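+        # e.g. 128 experts sharded across ep_size=4 ranks leaves 32 experts on this rank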
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
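+# Illustrative only: for x = [1, 2, 3], the inclusive cumsum is [1, 3, 6] while the
+# exclusive cumsum is [0, 1, 3].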
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/activation_fn.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/all_to_all.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
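+# Illustrative use from one rank's point of view (assuming a 2-rank group and split
+# sizes that agree with the peers'):
+#   out, handle = all_to_all(x, output_split_sizes=[4, 2], input_split_sizes=[3, 3], group=group)
+# sends x[:3] to rank 0 and x[3:6] to rank 1, and receives 4 rows from rank 0 plus
+# 2 rows from rank 1 into `out`; `handle` is None unless async_op=True.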
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/arguments.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/common.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/dmlp_registry.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e. only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/dmoe.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/gelu.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
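+# The factor `ff` in _gelu_backward_inplace is the derivative of the tanh-approximate
+# GeLU 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))); the constants are
+# 0.79788456 ~= sqrt(2/pi) and 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi).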
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/glu.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+    """GLU for shared expert.
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/memory_test.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/mlp.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+    Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class.
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+            # Use a weighted sum for the shared expert output,
+            # weighted by the number of experts used.
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
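+
+
+# Combination sketch (hypothetical numbers, not part of the original module):
+# with moe_top_k = 2 and shared_expert_weighted_sum enabled, the shared expert
+# counts as one extra expert, so t_experts = 3 and add_experts_sharedexpert
+# returns shared_expert_out / 3 + expert_out * (2 / 3); with the flag disabled
+# the two outputs are simply summed.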
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/moe.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
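+
+
+# Usage sketch (assumed integration, not part of this module): each MoE layer
+# records (tokens_per_expert, scores) via save_load_balancing_loss in its
+# forward pass; a training loop can then do roughly
+#
+#     loss = task_loss + batched_load_balancing_loss(args)
+#     loss.backward()
+#     clear_load_balancing_loss()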
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped so that the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
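+
+    # Capacity sketch (hypothetical numbers): with top_k = 2, 4096 tokens,
+    # an expert parallel world size of 8 and 64 experts, tokens_per_expert is
+    # 2 * 4096 * 8 / 64 = 1024, so a capacity factor of 1.25 yields an
+    # expert_capacity of 1280 slots per expert.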
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+            # If we're sharding the experts along the hidden dimension,
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+        # If we're sharding the experts along the hidden dimension,
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+            # Calculate the bin boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/mpu.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
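+
+
+# Sharding sketch (hypothetical numbers): with an expert parallel world size
+# of 8 and moe_num_experts = 4, expert_sharding_degree(args) = 4 and
+# hidden_sharding_degree(args) = 8 // 4 = 2, so each expert is split across
+# two ranks, experts_per_rank(args) = 1 and features_per_rank(args) is
+# args.ffn_hidden_size // 2.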
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/router.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
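+
+# Example (hypothetical values): for an expert_indices tensor with 4 elements
+# and num_experts = 2, _uniform_expert_assignment produces [0, 1, 0, 1]
+# reshaped to the input shape, ignoring the router scores entirely; this is
+# intended for benchmarking only, not for training.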
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
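+
+
+# Shape sketch (assumed shapes): with hidden_size = 1024, moe_num_experts = 8
+# and moe_top_k = 2, calling the router on an [sl, bs, 1024] activation
+# returns scores of shape [sl * bs, 8] and expert_weights / expert_indices of
+# shape [sl * bs, 2].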
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_layers/sharedexpert_registry.py b/build/torch29-cxx11-cu130-aarch64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_megablocks_cuda_6e04dec.abi3.so b/build/torch29-cxx11-cu130-aarch64-linux/_megablocks_cuda_6e04dec.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..51d57a635e3e06c4a428fea2de0175b62370f823
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_megablocks_cuda_6e04dec.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8669b2a5cf6f36ab1d6c518040d4f4e2874d7b1c5880b4424d21f89c60e77c5f
+size 12070448
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f202b8db3c3f3028303ab4308cf35f950e2c74
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_6e04dec
+
+ops = torch.ops._megablocks_cuda_6e04dec
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_6e04dec::{op_name}"
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_version.py b/build/torch29-cxx11-cu130-aarch64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/backend/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/backend/kernels.py b/build/torch29-cxx11-cu130-aarch64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
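+# Shape sketch (hypothetical sizes): for 6 tokens, top_k = 1 and a padded_bins
+# vector ending in 8, padded_gather above returns an [8, hidden_size] tensor
+# in which two rows remain zero padding; gather below involves no padding and
+# always returns [tokens * top_k, hidden_size] rows.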
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/benchmark_util.py b/build/torch29-cxx11-cu130-aarch64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+    print(f'mean time = {time:.3f}ms, std time = {std:.3f}ms')
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
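+
+
+# Usage sketch (hypothetical callable and arguments):
+#
+#     mean_ms, std_ms = benchmark_function(lambda: mlp(x), iterations=100)
+#     log_benchmark('MLP', {'batch_size': x.shape[0]}, mean_ms, std_ms)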
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/cpu_fused_moe.py b/build/torch29-cxx11-cu130-aarch64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
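+
+# Shape sketch (hypothetical sizes): for 16 tokens, hidden_size = 64,
+# 8 experts and moe_top_k = 2, route_tokens_cpu returns logits of shape
+# [16, 8] plus expert_weights and expert_indices of shape [16, 2]; the
+# weights are softmaxed over the selected top-k entries only.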
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation processes all experts in parallel using batched operations
+ instead of sequential for loops, which is more efficient on CPU.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # Process each expert in turn: find the (token, top-k position) pairs
+    # routed to it, run its MLP on those tokens, and accumulate the
+    # weighted outputs into `output`.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
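+
+
+# Illustrative sketch (not part of the kernel API): wire route_tokens_cpu and
+# cpu_fused_moe together on tiny random tensors. All shapes and names below are
+# hypothetical and only meant to show the expected layouts.
+def _example_cpu_fused_moe():  # pragma: no cover
+ num_tokens, hidden, inter, num_experts, top_k = 4, 8, 16, 4, 2
+ x = torch.randn(num_tokens, hidden)
+ router_weight = torch.randn(num_experts, hidden)
+ _, weights, ids = route_tokens_cpu(x, router_weight, None, top_k, num_experts)
+ w1 = torch.randn(num_experts, hidden, 2 * inter) * 0.02  # gate_up_proj layout
+ w2 = torch.randn(num_experts, inter, hidden) * 0.02  # down_proj layout
+ out = cpu_fused_moe(x, w1, w2, weights, ids, is_interleaved=False)
+ assert out.shape == (num_tokens, hidden)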
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/cpu_moe_cpp.py b/build/torch29-cxx11-cu130-aarch64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
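+
+# Illustrative call pattern (hypothetical shapes; requires the compiled _ops
+# extension). Note the layout difference from cpu_fused_moe: here w1 is
+# [E, 2N, K] and w2 is [E, K, N], topk_ids must be int32, and passing bf16
+# activations avoids the internal float32 -> bfloat16 round trip:
+#
+#   out = fused_moe_cpp(
+#       hidden_states=x.to(torch.bfloat16),   # [M, K]
+#       w1=w1, w2=w2,                         # VNNI-packed when is_vnni=True
+#       topk_weights=weights, topk_ids=ids.to(torch.int32),
+#       w1_bias=b1, w2_bias=b2,
+#       alpha=1.702, limit=7.0,               # enables the swigluoai path
+#       is_vnni=True,
+#   )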
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "MegaBlocksMoeMLP"]
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm/backend.py b/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
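+
+# Shape sketch (illustrative): with trans_a=False and trans_b=False, `a` holds
+# all tokens as [sum(batch_sizes), K], `b` stacks one [K, N] weight per group
+# as [num_groups, K, N], and the result is [sum(batch_sizes), N]. For example,
+# batch_sizes=[3, 5] with K=8 and N=16 yields an [8, 16] output.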
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm/ops.py b/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm_util.py b/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # grouped_gemm is vendored in this repository, so no external import is
+ # needed here.
+ # import grouped_gemm
+ _grouped_gemm_is_available = True
+except ImportError:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/layers.py b/build/torch29-cxx11-cu130-aarch64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: Optional[float] = None,
+ moe_normalize_expert_weights: Optional[int] = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
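+
+# Shape note (illustrative): mlp_forward expects tokens already grouped per
+# expert, i.e. x is [num_experts, expert_capacity, hidden_size], w1 is
+# [num_experts, hidden_size, 2 * ffn_size] with gate/up interleaved along the
+# last dimension, and w2 is [num_experts, ffn_size, hidden_size]; the output
+# has the same shape as x.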
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
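+
+# Worked example (illustrative): with moe_top_k=4 and
+# shared_expert_weighted_sum=True the shared expert contributes 1/5 of the
+# combined output and the routed experts 4/5; otherwise the two outputs are
+# simply added.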
+
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: Optional[torch.distributed.ProcessGroup],
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
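+
+# Worked example (illustrative): 1024 tokens, top_k=4, 128 experts, a capacity
+# factor of 1.0 and no expert parallelism give 4 * 1024 * 1 / 128 = 32 slots
+# per expert.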
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+ tokens, score_num_experts = expert_scores.size()
+ assert score_num_experts == num_experts
+ assert len(tokens_per_expert.size()) == 1
+ (count_num_experts,) = tokens_per_expert.size()
+ assert count_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
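+
+# Note: `indices` permutes tokens into expert-sorted order, `bin_ids` holds the
+# expert id at each sorted position, and `bins` are the inclusive cumulative
+# token counts that delimit each expert's segment.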
+
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
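+
+ # Illustrative usage (hypothetical sizes): shared expert weights can be
+ # attached after construction, e.g.
+ #   mlp = MegaBlocksMoeMLPWithSharedExpert()
+ #   up_w, down_w, up_b, down_b = create_shared_expert_weights(
+ #       hidden_size=1024, shared_expert_hidden_size=4096,
+ #       device=torch.device("cuda"), dtype=torch.bfloat16,
+ #       init_method=torch.nn.init.normal_)
+ #   mlp.set_shared_expert_weights(up_w, down_w, up_b, down_b)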
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
+
+# Backend overrides: on XPU, replace MegaBlocksMoeMLP with the fused XPU
+# implementation; the C++ CPU implementation is always re-exported below.
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/megablocks/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is: after adding it to `sys.modules`, it
+ # would also be used for other imports. So we derive a unique module name
+ # from the hex-encoded hash of the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/metadata.json b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a9813b81c6c98110d265c184f2016d728202289b
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/all_to_all_benchmark.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
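+    # Each rank exchanges an equal slice of the sequence with every other rank.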
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per element (half precision).
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/binned_gather.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/binned_scatter.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/cumsum.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
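+        # Handle 1D inputs by promoting them to a single row, running the cumsum
+        # along that row, and squeezing the result back.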
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/gather.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/histogram.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+# import stk
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1), which in turn calls
+# torch.as_strided(...). Circumvent this chain to avoid the overhead
+# it adds.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd:DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradX:DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradW:DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/padded_gather.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/padded_scatter.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
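+        # Only stash the activations when a gradient w.r.t. the scatter weights is needed.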
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.binned_gather(x, indices, bins, ec, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.binned_gather(x, indices, bins, ec, 1)
+
+ def benchmark():
+            return ops.binned_scatter(x, indices, None, bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ def benchmark():
+            return ops.padded_scatter(x, indices, bin_ids, None, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ # ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/repeat.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/replicate.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
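+        # Reduce the replicated gradient segments back to one value per bin.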
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/round_up.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of fewer than 1k elements.
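+    # Integer ceil-division trick: round each entry up to the nearest multiple of `value`.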
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/scatter.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/sort.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
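+        # `end_bit` limits the sort to the low-order bits of the keys;
+        # default to the full width of the dtype.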
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/stk_autocast.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
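+    # Recursively cast eligible CUDA floating-point tensors (including those
+    # nested in dicts, lists, and tuples) to the autocast dtype.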
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/sum.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/ops/topology.py b/build/torch29-cxx11-cu130-aarch64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block to give a better error message and
+# instructions for building the C++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+# import stk.random
+# import stk.ops
+# from stk.matrix import Matrix
+
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/autocast.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/sputnik.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
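+    # Treat a tensor as transposed when it is a column-major view: unit stride
+    # along dim 0 and a dim-1 stride equal to the number of rows.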
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/triton_kernels.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+    error_string = "incompatible dimensions: dimension of length {} must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
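+    # dsd: dense output = block-sparse lhs (BCSR blocks in `data`) @ dense `rhs`.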
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+ # return out
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
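+    # dds: dense output = dense `lhs` @ block-sparse rhs (BCSR blocks in `data`).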
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
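+    # sdd: block-sparse output (pattern given by row/column_indices) = dense `lhs` @ dense `rhs`.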
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/matrix.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+ if data.dim() != 3:
+ raise ValueError(
+ "Expected 3D shape for data (nnz, block, block). "
+ f"Got shape {data.dim()}D shape.")
+
+ block_size = data.shape[1]
+ if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+ raise ValueError(
+ "Matrix shape must be dividible by blocking. "
+ f"Got shape {shape} with "
+ f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
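+    # block_offsets_t[i] is the position in the original data array of the
+    # i-th block when blocks are traversed in transposed (column-major) order.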
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
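+    Illustrative layout: for a 256x256 matrix with 128x128 blocking in which
+    only blocks (0, 0) and (1, 1) are nonzero, data has shape [2, 128, 128],
+    row_indices = [0, 1], column_indices = [0, 1] and offsets = [0, 1, 2].
+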
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+ raise ValueError(
+ "Sparse matrix with shape {size} exceeds representable "
+ "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
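+
+    Example (illustrative):
+        c = stk.ops.mul(a, stk.ops.ones_like(a))  # c has the same stored blocks as a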
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/linear_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
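+    # `topo` provides the output sparsity pattern; the returned Matrix reuses
+    # its indices/offsets around the newly computed block data.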
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# can be implemented much more simply than it is here.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
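+    # Expands block-level (row, col) index pairs into element-level indices by
+    # replicating each pair blocking x blocking times and adding per-element
+    # offsets within the block.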
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/random/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/random/random_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py b/build/torch29-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from .. import random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/xpu_fused_moe.py b/build/torch29-cxx11-cu130-aarch64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# default
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
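+    # Heuristic: return the smallest candidate block size for which the
+    # per-expert block counters (num_blocks_per_seq * num_experts_per_node)
+    # fit within a single block; otherwise fall back to 1024.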
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
+
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
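+    # Net effect (per packed byte, high nibble first): subtract 8 from each
+    # unsigned 4-bit value and re-pack the results as signed 4-bit
+    # two's-complement nibbles.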
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
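+    # Shape sketch (illustrative): with num_rows=4, hidden_size=1024,
+    # inter_size=2048 and topk=4, hidden_states is [4, 1024], w13 is
+    # [num_experts, 2*2048, 1024], topk_ids is [4, 4] and the returned
+    # output is [4, 1024].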
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ func. Temporarily exposed here before GEMM fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
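+    # Each buffer below is padded up to a 256-byte boundary and packed
+    # back-to-back into one uint8 workspace; ws_map records
+    # (padded_size, byte_offset) for later slicing.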
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
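+
+    Example (illustrative): the exclusive cumsum of [1, 2, 3] along dim 0
+    is [0, 1, 3].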
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
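+
+    Example (illustrative): for expert ids [0, 1, 1, 3] and num_bins=4 the
+    result is [1, 2, 0, 1].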
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
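+
+    Example (illustrative): cumsum([1, 2, 3]) gives [1, 3, 6], while
+    cumsum([1, 2, 3], exclusive=True) gives [0, 1, 3].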
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
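+
+    Example (illustrative): argsort([3, 1, 2]) returns ([1, 2, 3], [1, 2, 0]).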
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+ "MyReplacementLayer",
+ # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/activation_fn.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/all_to_all.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
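+# Autograd-aware wrapper around dist.all_to_all_single. Illustrative usage
+# sketch: on each rank, input_split_sizes[i] rows of x are sent to rank i and
+# output_split_sizes[i] rows are received from rank i, so the returned tensor
+# has sum(output_split_sizes) rows (the handle is None unless async_op=True).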
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/arguments.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by the number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+ try:
+ import triton
+ if triton.__version__ >= '3.2.0':
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/common.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/dmlp_registry.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
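+
+    Example (illustrative): mlp_type='glu' with mlp_impl='grouped' yields a
+    GroupedGLU, while mlp_type='mlp' with mlp_impl='sparse' yields a SparseMLP.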
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/dmoe.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
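+        # Example (illustrative): with blocking=128, 256 padded tokens and
+        # ffn_hidden_size=512, block_rows=2, blocks_per_row=4 and
+        # offsets=[0, 4, 8].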
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ # position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy, # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capactiy,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/gelu.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/glu.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+ # GeLU.
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+ """GPU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/memory_test.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/mlp.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+    def backward(ctx: Any, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
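+# scale_gradient(w, s) is the identity in the forward pass but multiplies the
+# incoming gradient by s in the backward pass; below it is used to divide
+# expert gradients by the expert-parallel world size.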
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
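+        # Example (illustrative): with moe_top_k=4 the shared expert output is
+        # weighted by 1/5 and the routed expert output by 4/5.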
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/moe.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert entries '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+ f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
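+
+# Worked example (illustrative numbers only, not part of the upstream API):
+# with moe_num_experts=8, moe_loss_weight=0.01, num_layers=4, tokens=1024 and
+# moe_top_k=2, the scale above is (8 * 0.01) / (4 * 1024 * 2) ≈ 9.77e-6, and
+# the returned loss is that scale times
+# torch.dot(tokens_per_expert, expert_scores).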
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
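+
+ # Worked example (illustrative numbers only): with top_k=2, tokens=1024,
+ # world_size=1, num_experts=8 and moe_capacity_factor=1.25,
+ # tokens_per_expert = 2 * 1024 * 1 / 8 = 256 and the capacity per expert
+ # is int(1.25 * 256) = 320 slots.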
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+ # expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/mpu.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
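+
+
+# Illustrative example (assumed numbers, not part of the module): with an
+# expert-parallel world size of 8 and moe_num_experts=4,
+# expert_sharding_degree() is 4 and hidden_sharding_degree() is 8 // 4 = 2,
+# so experts_per_rank() is 1 and each rank holds ffn_hidden_size // 2
+# features of the experts it shares with one other rank.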
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/router.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
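+
+# Example (illustrative): for a flat index tensor with 6 elements and
+# num_experts=3, _uniform_expert_assignment produces [0, 1, 2, 0, 1, 2]
+# (reshaped to the input shape), ignoring the learned router scores.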
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
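+
+ # Example (illustrative): with moe_jitter_eps=0.01 each element of x is
+ # scaled by an independent uniform factor drawn from [0.99, 1.01).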
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch29-cxx11-cu130-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_megablocks_cuda_6e04dec.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_megablocks_cuda_6e04dec.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..3548f7fb815188fc523c90fa2111a7b14bf82e95
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_megablocks_cuda_6e04dec.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e1383adbf7afa208f0769d84a826fcd43de9ee9ce39d676ebce97698759c526
+size 12031416
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f202b8db3c3f3028303ab4308cf35f950e2c74
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_cuda_6e04dec::{op_name}"
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_version.py b/build/torch29-cxx11-cu130-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/backend/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/backend/kernels.py b/build/torch29-cxx11-cu130-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if not torch.cuda.is_available():
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
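+#
+# Illustrative example (assumed counts): if two experts receive 3 and 5
+# tokens, bins = [3, 8] (inclusive cumulative counts). If each expert's slice
+# is padded up to a multiple of 4, padded_bins = [4, 12], so the
+# padded_gather output defined below has 12 rows.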
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since its rows may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has a greater or equal
+ # number of rows since its rows may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+ # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/benchmark_util.py b/build/torch29-cxx11-cu130-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/cpu_fused_moe.py b/build/torch29-cxx11-cu130-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
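+
+
+# Minimal usage sketch for the two activations above (illustrative shapes
+# only):
+#
+#   gate, up = torch.randn(4, 16), torch.randn(4, 16)
+#   y1 = swigluoai_activation(gate, up)      # GptOss-style clamped SwiGLU
+#   y2 = silu_and_mul_activation(gate, up)   # standard SiLU(gate) * up
+#   # y1, y2: [4, 16]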
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
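+
+
+# Minimal usage sketch (illustrative shapes only):
+#
+#   x = torch.randn(2, 3, 64)            # [batch, seq, hidden]
+#   router_w = torch.randn(8, 64)        # [num_experts, hidden]
+#   logits, weights, ids = route_tokens_cpu(x, router_w, None, 2, 8)
+#   # logits: [6, 8]; weights, ids: [6, 2]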
+
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+ This implementation loops over experts, but batches all tokens routed to
+ each expert into a single matrix multiply, which is more efficient on CPU
+ than processing tokens one at a time.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+ # For each expert, find the (token_idx, topk_pos) pairs routed to it and
+ # process those tokens as one batch.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
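+
+
+# Minimal usage sketch for cpu_fused_moe (illustrative shapes, standard
+# non-interleaved [gate | up] layout):
+#
+#   tokens, hidden, inter, experts, topk = 6, 64, 128, 8, 2
+#   hs = torch.randn(tokens, hidden)
+#   w1 = torch.randn(experts, hidden, 2 * inter)
+#   w2 = torch.randn(experts, inter, hidden)
+#   weights = torch.rand(tokens, topk).softmax(dim=-1)
+#   ids = torch.randint(0, experts, (tokens, topk))
+#   out = cpu_fused_moe(hs, w1, w2, weights, ids,
+#                       activation="silu", is_interleaved=False)
+#   # out: [tokens, hidden]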
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/cpu_moe_cpp.py b/build/torch29-cxx11-cu130-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py b/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm/ops.py b/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
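+
+# Example (illustrative only): two expert groups of 3 and 5 rows multiplied
+# against their own weight matrices in a single call. Names are placeholders.
+#
+#   a = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)      # rows grouped by expert
+#   b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16)  # one [16, 32] matrix per group
+#   batch_sizes = torch.tensor([3, 5])                               # rows per group, kept on CPU
+#   out = gmm(a, b, batch_sizes)                                     # -> [8, 32]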
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm_util.py b/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+ # import grouped_gemm
+ pass
+ _grouped_gemm_is_available = True
+except ImportError as error:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+    )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/layers.py b/build/torch29-cxx11-cu130-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
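+# Illustrative behaviour of the patches above: while torch.compile is tracing,
+# ops.sort(top_experts.int(), sort_end_bit) returns two empty placeholders with
+# the input's shape and dtype; in eager mode the call still dispatches to the
+# real CUDA kernel.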
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
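+# Worked example (illustrative): with world_size=4, moe_num_experts=2 and
+# ffn_hidden_size=3072, expert_sharding_degree is 2, hidden_sharding_degree is 2,
+# experts_per_rank is 1 and features_per_rank is 1536.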
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
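+# Example (illustrative): routing 6 tokens of width 16 across 8 experts with
+# moe_top_k=2 returns logits of shape [6, 8] plus expert_weights and
+# expert_indices of shape [6, 2]; the weights are softmax-normalized over the
+# top-k dimension.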
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
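+
+# The fused activation above is the clamped SwiGLU variant used by GPT-OSS-style
+# experts. Gate and up are the even/odd columns of the fused gate_up projection;
+# with gate_c = min(gate, limit) and up_c = clamp(up, -limit, limit), each expert
+# computes ((up_c + 1) * gate_c * sigmoid(alpha * gate_c)) @ w2 + w2_bias.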
+
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
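+# Worked example (illustrative): with moe_top_k=4 and
+# shared_expert_weighted_sum=True the shared output is scaled by 1/5 and the
+# routed-expert output by 4/5; with the flag unset the two are simply added.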
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: int,
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, scores_num_experts = expert_scores.size()
+    assert scores_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (counts_num_experts,) = tokens_per_expert.size()
+    assert counts_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
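+# Worked example (illustrative): for num_experts=4, tokens=8 and top_k=2,
+# perfectly balanced routing gives tokens_per_expert=[4, 4, 4, 4] and mean
+# expert_scores of 0.25 each, so the loss is (4 / 16) * dot([4]*4, [0.25]*4) = 1.0.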
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
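+# Worked example (illustrative): top_expert=[2, 0, 1, 0] with num_experts=3
+# yields tokens_per_expert=[2, 1, 1], bins=[2, 3, 4], bin_ids=[0, 0, 1, 2] and
+# indices=[1, 3, 2, 0] (token positions reordered so experts are contiguous).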
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
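+# Worked example (illustrative): tokens=1024, top_k=4 and num_experts=128 on a
+# single rank with moe_capacity_factor=1.0 gives 4 * 1024 / 128 = 32 slots per expert.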
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: int = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+        # Kick off an async all_to_all to exchange the token counts across ranks
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
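+# Example (illustrative only): attaching a shared expert once the layer's
+# router and experts have been set up elsewhere. The initializer below is a
+# placeholder; any in-place init function works.
+#
+#   layer = MegaBlocksMoeMLPWithSharedExpert()
+#   up_w, down_w, up_b, down_b = create_shared_expert_weights(
+#       hidden_size=1152, shared_expert_hidden_size=3072,
+#       device=torch.device("cuda"), dtype=torch.bfloat16,
+#       init_method=torch.nn.init.xavier_uniform_)
+#   layer.set_shared_expert_weights(up_w, down_w, up_b, down_b, weighted_sum=False)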
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/megablocks/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a9813b81c6c98110d265c184f2016d728202289b
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json
@@ -0,0 +1,18 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "7.5",
+ "8.0",
+ "8.6",
+ "8.7",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+    dist.init_process_group(backend='nccl')
+    group = dist.group.WORLD
+    local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/binned_gather.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/binned_scatter.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ assert len(x.size()) == 3
+ ctx.bin_size = x.size(1)
+ ctx.top_k = top_k
+
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
+ # calculate the gradient w.r.t. 'weights'.
+ ctx.save_for_backward(x, indices, weights, bins)
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ x, indices, weights, bins = ctx.saved_tensors
+ out = kernels.binned_gather(
+ grad,
+ indices,
+ weights,
+ bins,
+ ctx.bin_size,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[2]:
+ wgrad = kernels.binned_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bins,
+ ctx.top_k,
+ )
+ return out, None, wgrad, None, None
+
+
+binned_scatter = BinnedScatterOp.apply
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/cumsum.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/cumsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/cumsum.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ # import megablocks_ops as ops # type: ignore
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("Could not import the megablocks C++ ops extension (megablocks._ops).") from e
+
+
+# Autograd wrappers for cumsum kernels.
+# NOTE: Does not support gradients.
+class ExclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.exclusive_cumsum(x, dim, out)
+ return out
+
+
+exclusive_cumsum = ExclusiveCumsumOp.apply
+
+
+class InclusiveCumsumOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
+ if len(x.size()) == 1:
+ x = x.view([1, -1])
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, 1, out)
+ return out.squeeze()
+ out = torch.empty_like(x)
+ ops.inclusive_cumsum(x, dim, out)
+ return out
+
+
+inclusive_cumsum = InclusiveCumsumOp.apply
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/gather.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/gather.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/histogram.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/histogram.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+ from .._ops import ops # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("Could not import the megablocks C++ ops extension (megablocks._ops).") from e
+
+
+# Autograd wrapper for histogram kernel.
+# NOTE: Does not support gradients.
+class HistogramOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
+ return ops.histogram(x, max_val)
+
+
+histogram = HistogramOp.apply
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/histogram_benchmark.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/histogram_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/histogram_benchmark.py
@@ -0,0 +1,78 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_HISTOGRAM_TESTS = (
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 4),
+ (16384, torch.int32, 8),
+ (16384, torch.int32, 16),
+ (16384, torch.int32, 32),
+ (16384, torch.int32, 64),
+ (16384, torch.int32, 128),
+ (16384, torch.int32, 256),
+)
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class HistogramBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testHistogram(self, n, dtype, max_val):
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
+ def testTorchHistogram(self, n, dtype, max_val):
+        x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/matmul_benchmark.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/matmul_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/matmul_benchmark.py
@@ -0,0 +1,415 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+
+from .. import stk
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+
+# Calling tensor.t() calls tensor.transpose(0, 1), which in turn calls
+# torch.as_strided(...). Build the transposed view directly to avoid the
+# overhead of that call chain.
+def transpose_view(x):
+ return torch.as_strided(
+ x,
+ (x.shape[1], x.shape[0]),
+ (x.stride()[1], x.stride()[0]),
+ )
+
+
+_MATMUL_TESTS = (
+ (64 * 1024, 512, 2048, 64),
+ (32 * 1024, 768, 3072, 64),
+ (8 * 1024, 1024, 4096, 64),
+ (4 * 2048, 4096, 4 * 4096, 4),
+)
+
+
+def log_benchmark(name, arguments, time, std, flops):
+ benchmark_util.log_benchmark(name, arguments, time, std)
+ print('flops = {:.2f}B'.format(flops / 1e9))
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
+ print('=' * 60)
+
+
+class MatmulBenchmark(parameterized.TestCase):
+
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
+ blocking = 128
+ padded_tokens, _ = x.size()
+ assert padded_tokens % blocking == 0
+ assert fhs % blocking == 0
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // blocking
+ blocks_per_row = fhs // blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ blocking,
+ block_rows,
+ blocks_per_row,
+ )
+ data = torch.empty(
+ column_indices.numel(),
+ blocking,
+ blocking,
+ dtype=torch.float16,
+ device=x.device,
+ )
+ shape = (padded_tokens, fhs * ne)
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
+
+ def build_input_matrix(self, sl, hs, ne):
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Assign tokens to experts uniformly.
+ top_expert = torch.arange(0, sl).cuda().int() % ne
+
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+ return out, padded_bins
+
+ def build_weight_matrix(self, ne, hs, fhs):
+ return torch.randn((hs, ne * fhs)).cuda().half()
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(x, w, topo)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(topo, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradX::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ topo = topo.t()
+
+ def benchmark():
+ return stk.ops.dsd(topo, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+
+ def benchmark():
+ return stk.ops.dsd(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DSD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ w = transpose_view(w)
+
+ def benchmark():
+ return stk.ops.sdd(out, w, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::SDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
+ out = stk.ops.dsd(x, w)
+ x = x.t()
+
+ def benchmark():
+ return stk.ops.dsd(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DSD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.nnz * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+
+ w = w.transpose(1, 2).contiguous()
+ w = w.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0::Fwd:DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = w.transpose(1, 2).contiguous()
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradX:DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
+ w = torch.randn((ne, hs, fhs)).cuda().half()
+ out = torch.bmm(x, w)
+ out = out.transpose(1, 2)
+
+ def benchmark():
+ return torch.bmm(out, x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '0:GradW:DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * fhs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+
+ def benchmark():
+ return torch.bmm(x, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::Fwd::DDD::NN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ w = torch.transpose(w, 1, 2)
+
+ def benchmark():
+ return torch.bmm(out, w)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradX::DDD::NT',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+ @parameterized.parameters(*_MATMUL_TESTS)
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
+ assert (sl % ne) == 0
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
+ w = torch.randn((ne, fhs, hs)).cuda().half()
+ out = torch.bmm(x, w)
+ x = torch.transpose(x, 1, 2)
+
+ def benchmark():
+ return torch.bmm(x, out)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'ffn_hidden_size': fhs,
+ 'num_experts': ne,
+ }
+ log_benchmark(
+ '1::GradW::DDD::TN',
+ arguments,
+ mean_t,
+ std_t,
+ x.numel() * hs * 2,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/padded_gather.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/padded_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/padded_gather.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_gather kernel.
+class PaddedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
+ ctx.top_k = top_k
+ return kernels.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
+ out = kernels.padded_scatter(
+ grad,
+ indices,
+ bin_ids,
+ None,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return out, None, None, None, None, None
+
+
+padded_gather = PaddedGatherOp.apply
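+
+
+# Illustrative usage (assumed names, mirrors the benchmarks in this repo):
+#   padded_x = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+# gathers tokens into expert order while padding each expert's bin out to the
+# block-aligned sizes encoded in padded_bins (see ops.round_up).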
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/padded_scatter.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/padded_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/padded_scatter.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for padded_scatter kernel.
+class PaddedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+ ):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ *maybe_x,
+ )
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.padded_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.padded_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None, None
+
+
+def padded_scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ padded_bins: torch.Tensor,
+ top_k: int,
+):
+ return PaddedScatterOp.apply(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
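+
+
+# Illustrative usage (assumed names, mirrors padded_scatter_benchmark.py):
+#   y = padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k)
+# un-permutes the padded, expert-ordered activations back to token order,
+# applying the per-token router weights during the reduction.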
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/padded_scatter_benchmark.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/padded_scatter_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/padded_scatter_benchmark.py
@@ -0,0 +1,66 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PADDED_SCATTER_BENCHMARK = (
+ # dMoE-Medium, 8-way EMP.
+ (1024 * 16, 1024, 8, 4),
+ # dMoE-Medium, post-all-to-all.
+ (1024 * 16 * 4, 1024, 8, 1),
+)
+
+
+class PaddedScatterTest(parameterized.TestCase):
+
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
+ def testPaddedScatter(self, sl, hs, ne, top_k):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ # Sample weights for the scatter reduce.
+ weights = torch.rand((sl * top_k,)).cuda().half()
+
+ # Gather the data to prepare for backwards.
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ def benchmark():
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+ benchmark_util.log_benchmark(
+ 'Padded Scatter',
+ {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ 'top_k': top_k,
+ },
+ time,
+ std,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/permute_benchmark.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/permute_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/permute_benchmark.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import torch
+from absl.testing import parameterized
+
+from .. import benchmark_util, ops
+
+_PERMUTE_TESTS = (
+ (16384, 768, 2),
+ (16384, 768, 4),
+ (16384, 768, 8),
+ (16384, 768, 16),
+ (16384, 768, 32),
+ (16384, 768, 64),
+ (16384, 768, 128),
+ (16384 * 8, 768, 2),
+ (16384 * 8, 768, 4),
+ (16384 * 8, 768, 8),
+ (16384 * 8, 768, 16),
+ (16384 * 8, 768, 32),
+ (16384 * 8, 768, 64),
+ (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedGather(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+ return ops.binned_gather(x, indices, bins, ec)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testBinnedScatter(self, sl, hs, ne):
+ # NOTE: Capacity factor == 1.
+ ec = sl // ne
+
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ x = ops.binned_gather(x, indices, bins, ec)
+
+ def benchmark():
+ return ops.binned_scatter(x, indices, bins)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedGather(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ def benchmark():
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testPaddedScatter(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+
+ # Randomly assign tokens to experts.
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+ bin_ids, indices = ops.sort(top_expert)
+ tokens_per_expert = ops.histogram(top_expert, ne)
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
+
+ def benchmark():
+            # No router weights and top_k == 1 for this benchmark.
+            return ops.padded_scatter(x, indices, bin_ids, None, bins, padded_bins, 1)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_PERMUTE_TESTS)
+ def testCopy(self, sl, hs, ne):
+ # Create the data and indices.
+ x = torch.randn((sl, hs)).cuda().half()
+ y = x.clone()
+
+ def benchmark():
+ return y.copy_(x)
+
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+ arguments = {
+ 'sequence_length': sl,
+ 'hidden_size': hs,
+ 'num_experts': ne,
+ }
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/repeat.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/repeat.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+ if all((t == 1 for t in tiling)):
+ return x
+ return x.repeat(*tiling)
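+
+
+# Example (illustrative only): repeat(x, torch.Size((1, 1))) returns x
+# unchanged, while repeat(x, torch.Size((2, 1))) tiles x twice along dim 0.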
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/replicate.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/replicate.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# The compiled C++ extension is wrapped in a try-block so that a missing
+# build produces a clearer error message for the custom operations.
+try:
+    from .._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(
+        "megablocks C++ ops ('_ops') not found; build the extension before importing this module.",
+    ) from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+ ctx.save_for_backward(bins)
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+ ops.replicate_forward(x, bins, out)
+ return out
+
+ @staticmethod
+ def backward(ctx: Any, grad: torch.Tensor):
+ bins, = ctx.saved_tensors
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+ ops.replicate_backward(grad, bins, out)
+ return out, None, None
+
+
+replicate = ReplicateOp.apply
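+
+
+# Illustrative usage (assumed names, not executed here): replicate expands one
+# value per (row, bin) pair into one value per output column, e.g. broadcasting
+# a per-expert scale across all tokens assigned to that expert.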
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/round_up.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+ assert isinstance(value, int)
+ assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of fewer than 1k elements.
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
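+
+
+# Example (illustrative only): with value=128, an int32 tensor [1, 128, 129]
+# rounds up to [128, 128, 256]; this is how per-expert token counts are padded
+# to the sparse block size before building padded bins.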
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/scatter.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+ ) -> torch.Tensor:
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = kernels.gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = kernels.scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+) -> Optional[torch.Tensor]:
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
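+
+
+# Illustrative usage (assumed names, not executed here):
+#   y = scatter(x, indices, bin_ids, weights, bins, top_k)
+# is the un-padded counterpart of padded_scatter: it maps expert-ordered rows
+# of x back to token order, weighting each copy by the router weights.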
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/sort.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/sort.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# The compiled C++ extension is wrapped in a try-block so that a missing
+# build produces a clearer error message for the custom operations.
+try:
+    from .._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(
+        "megablocks C++ ops ('_ops') not found; build the extension before importing this module.",
+    ) from e
+
+_BITS_FOR_DTYPE = {
+ torch.int16: 16,
+ torch.int32: 32,
+ torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if end_bit is None:
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ ops.sort(x, end_bit, x_out, iota_out)
+ return (x_out, iota_out)
+
+
+sort = SortOp.apply
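+
+
+# Illustrative usage (assumed names, mirrors the benchmarks in this repo):
+#   bin_ids, indices = sort(top_expert)
+# returns the expert assignments in ascending order together with the
+# permutation that produced them; end_bit can be lowered when the values are
+# known to be small (e.g. ceil(log2(num_experts))).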
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/sort_benchmark.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/sort_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/sort_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+ (16384, torch.int32, None),
+ (16384, torch.int32, 2),
+ (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+ types = {
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.int64,
+ }
+ return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+ # Run once to get rid of startup overhead.
+ fn()
+ times = []
+ for _ in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ fn()
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ times = np.array(times)
+ return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+ print('=' * 60)
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+ print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+ @parameterized.parameters(*_SORT_TESTS)
+ def testSort(self, n, dtype, max_val):
+ if max_val is None:
+ max_val = np.iinfo(numpy_dtype(dtype)).max
+ end_bit = int(np.ceil(np.log2(max_val)))
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+ arguments = {
+ 'n': n,
+ 'dtype': dtype,
+ 'max_val': max_val,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
+ def testTorchSort(self, n):
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+ arguments = {
+ 'n': n,
+ }
+ log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/stk_autocast.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/stk_autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/stk_autocast.py
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/sum.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/sum.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+ if x.shape[dim] == 1:
+ return x.squeeze(dim=dim)
+ return x.sum(dim=dim)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/ops/topology.py b/build/torch29-cxx11-cu130-x86_64-linux/ops/topology.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/ops/topology.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# The compiled C++ extension is wrapped in a try-block so that a missing
+# build produces a clearer error message for the custom operations.
+try:
+    from .._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(
+        "megablocks C++ ops ('_ops') not found; build the extension before importing this module.",
+    ) from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+ ):
+ out = torch.empty(
+ output_block_rows * output_block_columns,
+ dtype=torch.int16,
+ device=padded_bins.device,
+ )
+ ops.indices(
+ padded_bins,
+ block_size,
+ output_block_rows,
+ output_block_columns,
+ out,
+ )
+ return out
+
+
+topology = TopologyOp.apply
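+
+
+# Illustrative usage (assumed names, mirrors matmul_benchmark.py in this repo):
+#   column_indices = topology(padded_bins, blocking, block_rows, blocks_per_row)
+# produces the int16 column indices of the nonzero blocks in the block-sparse
+# (BCSR) topology used for the expert computation.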
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/__init__.py
@@ -0,0 +1,7 @@
+from . import random
+from . import ops
+from .matrix import Matrix
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/autocast.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/autocast.py
@@ -0,0 +1,37 @@
+import functools
+import torch
+
+
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+    elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+ return decorate_bwd
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/sputnik.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/sputnik.py
new file mode 100644
index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/sputnik.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..backend import triton_kernels as backend
+from ..backend.autocast import custom_bwd, custom_fwd
+
+
+def _standardize_shape(x, transpose):
+ if transpose:
+ return torch.Size((x[1], x[0]))
+ return x
+
+
+def _sparse_transpose(x):
+ return (torch.Size((x[0][1], x[0][0])), ) + x[1:]
+
+
+def _transpose_helper(x, transpose):
+ if isinstance(x, torch.Tensor):
+ return x.t() if transpose else x
+ if transpose:
+ x = _sparse_transpose(x)
+ return x + (transpose,)
+
+
+def _wrap(x):
+ if isinstance(x, torch.Tensor):
+ return (x,)
+ return x
+
+
+def _is_transposed(x):
+ return (not x.is_contiguous() and
+ x.stride()[0] == 1 and
+ x.stride()[1] == x.size()[0])
+
+
+def _call_helper(op, out, a, b, trans_a, trans_b):
+ args = (_wrap(_transpose_helper(a, trans_a)) +
+ _wrap(_transpose_helper(b, trans_b)))
+ if isinstance(out, tuple):
+ args = args + out
+ return op(*args)
+
+
+def _preprocess_inputs(lhs, rhs, dy):
+ if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
+ lhs = lhs.t()
+ if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
+ rhs = rhs.t()
+ if (isinstance(dy, torch.Tensor) and
+ not dy.is_contiguous() and
+ not _is_transposed(dy)):
+ dy = dy.contiguous()
+ if isinstance(dy, tuple) and not dy[1].is_contiguous():
+ dy = (dy[0], dy[1].contiguous()) + dy[2:]
+ return lhs, rhs, dy
+
+
+def _postprocess_outputs(x, transpose, grad):
+ if isinstance(x, torch.Tensor) and transpose:
+ return grad.t()
+ return grad
+
+
+def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (rhs, dy) if trans_lhs else (dy, rhs)
+ trans_a = trans_lhs and trans_rhs
+ trans_b = trans_lhs or not trans_rhs
+ out = _call_helper(op, lhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(lhs, trans_lhs, out)
+
+
+def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
+ lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)
+
+ a, b = (dy, lhs) if trans_rhs else (lhs, dy)
+ trans_a = not trans_lhs or trans_rhs
+ trans_b = trans_lhs and trans_rhs
+ out = _call_helper(op, rhs, a, b, trans_a, trans_b)
+ return _postprocess_outputs(rhs, trans_rhs, out)
+
+
+class DSD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs):
+ ctx.save_for_backward(data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ rhs)
+ ctx.shape = _standardize_shape(shape, transpose_a)
+ ctx.transpose_a = transpose_a
+
+ out = torch.empty(
+ (shape[0], rhs.size()[1]),
+ dtype=rhs.dtype,
+ device=rhs.device)
+
+ backend.dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = (ctx.shape,) + saved_tensors[:-1]
+ rhs = saved_tensors[-1]
+ trans_a = ctx.transpose_a
+ trans_b = _is_transposed(rhs)
+
+ ddata = None
+ if ctx.needs_input_grad[1]:
+ ddata = _lhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[-1]:
+ op = dds if trans_b else dsd
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return None, ddata, None, None, None, None, None, None, None, drhs
+
+
+dsd = DSD.apply
+
+
+class DDS(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b):
+ ctx.save_for_backward(lhs,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = _standardize_shape(shape, transpose_b)
+ ctx.transpose_b = transpose_b
+ out = torch.empty((lhs.size()[0], shape[1]),
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs = saved_tensors[0]
+ rhs = (ctx.shape,) + saved_tensors[1:]
+ trans_a = _is_transposed(lhs)
+ trans_b = ctx.transpose_b
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dsd if trans_a else dds
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ ddata = None
+ if ctx.needs_input_grad[2]:
+ ddata = _rhs_gradient(sdd,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, None, ddata, None, None, None, None, None, None, None
+
+
+dds = DDS.apply
+
+
+class SDD(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx,
+ lhs,
+ rhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t):
+ ctx.save_for_backward(
+ lhs,
+ rhs,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t)
+ ctx.shape = shape
+ out = torch.empty(
+ data.shape,
+ dtype=lhs.dtype,
+ device=lhs.device)
+ backend.sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy):
+ saved_tensors = ctx.saved_tensors
+ lhs, rhs = saved_tensors[:2]
+ dy = (ctx.shape, dy) + saved_tensors[2:]
+ trans_a = _is_transposed(lhs)
+ trans_b = _is_transposed(rhs)
+
+ dlhs = None
+ if ctx.needs_input_grad[0]:
+ op = dds if trans_a else dsd
+ dlhs = _lhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ drhs = None
+ if ctx.needs_input_grad[1]:
+ op = dsd if trans_b else dds
+ drhs = _rhs_gradient(op,
+ lhs,
+ rhs,
+ dy,
+ trans_a,
+ trans_b)
+ return dlhs, drhs, None, None, None, None, None, None, None, None
+
+
+sdd = SDD.apply
+
+class RowIndices(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, shape, data, offsets, column_indices):
+ out = torch.empty(
+ column_indices.shape,
+ dtype=column_indices.dtype,
+ device=column_indices.device)
+ backend.row_indices(shape, data, offsets, column_indices, out)
+ return out
+
+
+row_indices = RowIndices.apply
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/triton_kernels.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/triton_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/backend/triton_kernels.py
@@ -0,0 +1,393 @@
+import torch
+import triton
+import triton.language as tl
+from dataclasses import dataclass
+
+@dataclass
+class TritonConfig:
+ BLOCK_M: int = 128
+ BLOCK_N: int = 128
+ BLOCK_K: int = 32
+ BLOCK_SIZE: int = 128
+ NUM_STAGES: int = 4
+ NUM_WARPS: int = 4
+
+def _validate_matmul_dims(M: int, K: int, N: int):
+ error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
+ assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
+ assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
+ assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _sdd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_m = tl.load(row_indices + pid)
+ pid_n = tl.load(column_indices + pid)
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ a = tl.load(A)
+ b = tl.load(B)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+    # Store to sparse matrix.
+ acc = acc.to(C.dtype.element_ty)
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ cm = tl.arange(0, BLOCK_M)
+ cn = tl.arange(0, BLOCK_N)
+ C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dsd_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_m)
+ end_inx = tl.load(offsets + pid_m + 1)
+
+ # pointers to sparse matrix
+ rm = tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to dense matrix
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+ ak_sub_incr = BLOCK_K * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+ bk_block_incr = BLOCK_SIZE * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_A:
+ ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+ else:
+ ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
+
+ ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
+
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': TritonConfig.BLOCK_M,
+ 'BLOCK_N': TritonConfig.BLOCK_N,
+ 'BLOCK_K': TritonConfig.BLOCK_K,
+ 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
+ }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
+ ],
+ key=['M', 'N', 'K'],
+)
+@triton.jit
+def _dds_kernel(A, B, C, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ row_indices, column_indices, offsets,
+ block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
+ ):
+
+ # matrix multiplication
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+ start_inx = tl.load(offsets + pid_n)
+ end_inx = tl.load(offsets + pid_n + 1)
+
+ # pointers to dense matrix
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rak = tl.arange(0, BLOCK_K)
+
+ A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
+
+ # pointers to sparse matrix
+ rn = tl.arange(0, BLOCK_N)
+ rbk = tl.arange(0, BLOCK_K)
+ B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
+
+ # do matrix multiplication
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
+
+ BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
+
+ ak_sub_incr = BLOCK_K * stride_ak
+ ak_block_incr = BLOCK_SIZE * stride_ak
+ bk_sub_incr = BLOCK_K * stride_bk
+
+ for k in range(nsub_blocks * (end_inx - start_inx)):
+ sub_block_inx = k % nsub_blocks
+ block_inx = k // nsub_blocks
+
+ if trans_B:
+ ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+ else:
+ ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
+
+ ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
+ a = tl.load(ptr_A)
+ b = tl.load(ptr_B)
+ acc += tl.dot(a, b)
+
+ acc = acc.to(C.dtype.element_ty)
+ cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
+ tl.store(C, acc, mask=True)
+
+def dsd(shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_a,
+ rhs,
+ out
+ ):
+
+ device = rhs.device
+ trans_A = transpose_a
+ trans_B = False
+
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = data.stride(1), data.stride(2)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+ a_column_indices = column_indices
+ a_offsets = offsets
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = data.stride(2), data.stride(1)
+ a_column_indices, a_offsets = column_indices_t, offsets_t
+
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _dsd_kernel[grid](
+ data.data, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, a_column_indices, a_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def dds(lhs,
+ shape,
+ data,
+ offsets,
+ row_indices,
+ column_indices,
+ offsets_t,
+ column_indices_t,
+ block_offsets_t,
+ transpose_b,
+ out
+ ):
+
+ device = lhs.device
+ trans_B = transpose_b
+ trans_A = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+
+ # checks constraints
+ assert lhs.shape[1] == shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = data.stride(1), data.stride(2)
+ b_column_indices = column_indices_t
+ b_offsets = offsets_t
+
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = data.stride(2), data.stride(1)
+ b_column_indices, b_offsets = column_indices, offsets
+
+ _dds_kernel[grid](
+ lhs, data, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(0), out.stride(1),
+ row_indices, b_column_indices, b_offsets,
+ block_offsets_t, trans_A, trans_B,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+def sdd(lhs,
+ rhs,
+ shape,
+ out,
+ offsets,
+ row_indices,
+ column_indices
+ ):
+
+ device = out.device
+ trans_A = False
+ trans_B = False
+
+ if lhs.stride(0) > 1 and lhs.stride(1) > 1:
+ trans_A = True
+ if rhs.stride(0) > 1 and rhs.stride(1) > 1:
+ trans_B = True
+
+ # checks constraints
+ assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
+ M, K = lhs.shape
+ _, N = rhs.shape
+
+ _validate_matmul_dims(M, K, N)
+
+ # accumulator types
+ ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+
+ # launch kernel
+ nnz_blocks = len(row_indices)
+ grid = lambda META: (nnz_blocks,)
+
+ stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
+ stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
+
+ if trans_A:
+ stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
+ if trans_B:
+ stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
+
+ _sdd_kernel[grid](
+ lhs, rhs, out, M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ out.stride(1), out.stride(2),
+ row_indices, column_indices,
+ GROUP_M=128, ACC_TYPE=ACC_TYPE
+ )
+
+@triton.jit
+def _row_indices_kernel(offsets, out):
+ pid = tl.program_id(0)
+ row_offset = tl.load(offsets + pid)
+ nnz_blocks = tl.load(offsets + pid + 1) - row_offset
+ for nnz_block in range(nnz_blocks):
+ tl.store(out + row_offset + nnz_block, pid)
+
+def row_indices(
+ shape, data, offsets, column_indices, out
+):
+ block_rows = len(offsets) - 1
+ _row_indices_kernel[(block_rows, )](offsets, out)
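+
+# Note (illustrative): _row_indices_kernel writes, for every nonzero block, the
+# index of the block-row it belongs to; that is, it expands the CSR offsets
+# into one row id per block, with one Triton program per block row.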
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/matrix.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/matrix.py
@@ -0,0 +1,329 @@
+import numpy as np
+import torch
+
+# 1. Add heavyweight (data) validation helper.
+# 2. Add construction helpers
+# 3. Make indentation consistent
+# 4. Replace asserts with descriptive errors.
+
+##
+### Validation helpers.
+##
+
+
+def _validate_matrix(shape, data, row_indices, column_indices, offsets):
+ # Data should be [nnz, block_size, block_size]
+ if data.dim() == 1:
+ data = torch.reshape(data, [data.numel(), 1, 1])
+
+ # Blocks should be square.
+ if data.shape[-2] != data.shape[-1]:
+ raise ValueError(
+ "Expected square blocking in data. "
+ f"Got block shape {[data.shape[-2], data.shape[-1]]}")
+
+ # Flatten batch dimensions on data - original shape preserved
+ # in shape argument.
+ block_size = data.shape[-1]
+ data = data.view([-1, block_size, block_size])
+
+    if data.dim() != 3:
+        raise ValueError(
+            "Expected 3D shape for data (nnz, block, block). "
+            f"Got {data.dim()}D data.")
+
+ block_size = data.shape[1]
+    if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
+        raise ValueError(
+            "Matrix shape must be divisible by blocking. "
+            f"Got shape {shape} with "
+            f"{[block_size, block_size]} blocking.")
+
+ if np.prod(shape) < data.numel():
+ raise ValueError(
+ "Invalid matrix. Number of nonzeros exceeds matrix capacity "
+ f"({data.numel()} v. {np.prod(shape)})")
+
+ if row_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
+
+ if column_indices.dim() != 1:
+ raise ValueError(
+ f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
+
+ if offsets.dim() != 1:
+ raise ValueError(
+ f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
+
+ if row_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
+
+ if column_indices.numel() != data.shape[0]:
+ raise ValueError(
+ "Expected 1 index per nonzero block. "
+ f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
+
+    block_rows = np.prod(shape[:-1]) // block_size
+ if offsets.numel() != block_rows + 1:
+ raise ValueError(
+ "Expected one offset per block row plus one. "
+ f"Got {offsets.numel()} offsets with {block_rows} block rows.")
+
+ is_cuda = (data.is_cuda and
+ row_indices.is_cuda and
+ column_indices.is_cuda and
+ offsets.is_cuda)
+ is_cpu = (not data.is_cuda and
+ not row_indices.is_cuda and
+ not column_indices.is_cuda and
+ not offsets.is_cuda)
+ if not (is_cuda or is_cpu):
+ raise ValueError(
+ "Expected data & meta-data on common device. "
+ f"Got data on {data.device}, row_indices on {row_indices.device} "
+ f"column_indices on {column_indices.device} and "
+ f"offsets on {offsets.device}.")
+
+ if data.dtype != torch.float16:
+ raise ValueError(
+ f"Expected float16 data. Got {data.dtype} data.")
+ if row_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
+ if column_indices.dtype != torch.int16:
+ raise ValueError(
+ f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
+ if offsets.dtype != torch.int32:
+ raise ValueError(
+ f"Expected int32 offsets. Got {offsets.dtype} offsets.")
+ return data
+
+
+def _transpose(size, data, row_indices, column_indices, offsets):
+ block_columns = size[1] // data.shape[1]
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ gather_indices = column_indices.argsort()
+ column_indices_t = row_indices.gather(0, gather_indices)
+ block_offsets_t = gather_indices.int()
+
+ # NOTE: Histogram is not implemented for any integer type on CPU. Do
+ # the histogram in 32-bit float, which can exactly represent 16-bit
+ # integers.
+ column_indices_float = column_indices.float()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
+ nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
+ nnz_per_column = nnz_per_column.int()
+ offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
+ return column_indices_t, offsets_t, block_offsets_t
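+
+# Worked example (illustrative): for nonzero blocks (0, 1), (1, 0) and (1, 1),
+# i.e. row_indices = [0, 1, 1] and column_indices = [1, 0, 1], a stable argsort
+# of the column indices gives gather_indices = [1, 0, 2]. The transpose then
+# has column_indices_t = [1, 0, 1], block_offsets_t = [1, 0, 2], and with one
+# block in column 0 and two in column 1, offsets_t = [0, 1, 3].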
+
+
+class Matrix(torch.nn.Module):
+ """A matrix stored in sparse format.
+
+ Underlying format is block compressed sparse row (BCSR).
+
+ TODO(tgale): Make this mirror torch.Tensor API as much as possible.
+ """
+
+ def __init__(self,
+ size,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t=None,
+ offsets_t=None,
+ block_offsets_t=None):
+ super().__init__()
+ self._size = size
+ self._data = data
+ self._row_indices = row_indices
+ self._column_indices = column_indices
+ self._offsets = offsets
+
+ # Produce the transpose meta-data if it is not passed in.
+ if ((column_indices_t is None) or (offsets_t is None) or
+ (block_offsets_t is None)):
+ column_indices_t, offsets_t, block_offsets_t = _transpose(
+ size, data, row_indices, column_indices, offsets)
+ self._column_indices_t = column_indices_t
+ self._offsets_t = offsets_t
+ self._block_offsets_t = block_offsets_t
+
+ self._transposed = False
+
+ # Validate that our metadata will not overflow.
+ max_dim = np.iinfo(np.int16).max * self.blocking
+ if column_indices.dtype == torch.int16:
+ if size[0] > max_dim or size[1] > max_dim:
+                raise ValueError(
+                    f"Sparse matrix with shape {size} exceeds representable "
+                    "size with 16-bit indices.")
+
+ def validate(self):
+ _validate_matrix(self._size,
+ self._data,
+ self._row_indices,
+ self._column_indices,
+ self._offsets)
+
+ # TODO(tgale): Add heavyweight data validation.
+
+ def to(self, device):
+ # TODO(tgale): Handle type conversions here. We
+ # need to set the appropriate meta-data type for
+ # the given floating-point type.
+ self._data = self._data.to(device)
+ self._row_indices = self._row_indices.to(device)
+ self._column_indices = self._column_indices.to(device)
+ self._offsets = self._offsets.to(device)
+ self._column_indices_t = self._column_indices_t.to(device)
+ self._offsets_t = self._offsets_t.to(device)
+ self._block_offsets_t = self._block_offsets_t.to(device)
+ return self
+
+ def cuda(self):
+ return self.to(torch.cuda.current_device())
+
+ def clone(self):
+ return Matrix(
+ self.size(),
+ self.data.clone(),
+ self.row_indices.clone(),
+ self.column_indices.clone(),
+ self.offsets.clone(),
+ self.column_indices_t.clone(),
+ self.offsets_t.clone(),
+ self.block_offsets_t.clone())
+
+ def t(self):
+ if self.dim() != 2:
+ raise ValueError(
+ "t() expects a tensor with <= 2 dimensions, "
+ f"but self is {self.dim()}D.")
+ out = Matrix(self.size(),
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ out._transposed = not self._transposed
+ out._size = torch.Size((self._size[1], self._size[0]))
+ return out
+
+ def contiguous(self):
+ raise ValueError("Not yet implemented.")
+
+ def is_contiguous(self):
+ return not self._transposed
+
+ @property
+ def is_cuda(self):
+ return self._data.is_cuda
+
+ @property
+ def device(self):
+ return self._data.device
+
+ def size(self):
+ return self._size
+
+ @property
+ def shape(self):
+ return self.size()
+
+ def dim(self):
+ return len(self._size)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def row_indices(self):
+ return self._row_indices
+
+ @property
+ def column_indices(self):
+ return self._column_indices
+
+ @property
+ def offsets(self):
+ return self._offsets
+
+ @property
+ def offsets_t(self):
+ return self._offsets_t
+
+ @property
+ def column_indices_t(self):
+ return self._column_indices_t
+
+ @property
+ def block_offsets_t(self):
+ return self._block_offsets_t
+
+ @property
+ def dtype(self):
+ return self.data.dtype
+
+ @property
+ def nnz(self):
+ return self.data.numel()
+
+ @property
+ def blocking(self):
+ return self.data.shape[1]
+
+ @property
+ def requires_grad(self):
+ return self.data.requires_grad
+
+ def requires_grad_(self, x):
+ self.data.requires_grad_(x)
+ return self
+
+ def view(self, *shape):
+ assert self.is_contiguous()
+ if shape[-1] != self.size()[-1]:
+ raise ValueError(
+ "Can't change view on compressed dimension. "
+ f"{self.size()[-1]} v. {shape[-1]}.")
+ if np.prod(shape) != np.prod(self.size()):
+ raise ValueError(
+ "Mismatch in numel of Matrix and new shape. "
+ f"{np.prod(self.size())} v. {np.prod(shape)}")
+ return Matrix(shape,
+ self.data,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+
+ @property
+ def grad(self):
+ # TODO(tgale): Make sure this mirrors torch.Tensor
+ # behavior in the case where we ask for the gradient
+ # of a non-contiguous tensor.
+ size = self.size()
+ if not self.is_contiguous():
+ size = torch.Size((size[1], size[0]))
+ out = Matrix(size,
+ self.data.grad,
+ self.row_indices,
+ self.column_indices,
+ self.offsets,
+ self.column_indices_t,
+ self.offsets_t,
+ self.block_offsets_t)
+ return out if self.is_contiguous() else out.t()
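+
+
+# Illustrative construction sketch (assumptions: CUDA device, float16 data):
+# a 256x256 matrix with 128x128 blocks and nonzero blocks at (0, 0) and (1, 1)
+# could be assembled as
+#
+#   data = torch.randn(2, 128, 128, dtype=torch.float16, device="cuda")
+#   row_indices = torch.tensor([0, 1], dtype=torch.int16, device="cuda")
+#   column_indices = torch.tensor([0, 1], dtype=torch.int16, device="cuda")
+#   offsets = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
+#   x = Matrix((256, 256), data, row_indices, column_indices, offsets)
+#
+# The transpose meta-data (column_indices_t, offsets_t, block_offsets_t) is
+# derived by _transpose() whenever it is not passed in explicitly.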
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/__init__.py
@@ -0,0 +1,3 @@
+from .linear_ops import dds, dsd, sdd
+from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
+from .eltwise_ops import mul
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops.py
@@ -0,0 +1,28 @@
+from ..matrix import Matrix
+
+def mul(a, b):
+ """Performs element-wise multiplication of matrices a and b.
+
+ It is the user's responsibility to make sure that a and b
+ follow the same matrix topology. This function assumes it is safe
+    to use the topology of a.
+
+ Args:
+ a: stk.Matrix.
+ b: stk.Matrix with a's matrix topology.
+
+ Returns:
+ stk.Matrix where the entries correspond to torch.mul(a, b).
+ """
+ assert isinstance(a, Matrix)
+ assert isinstance(b, Matrix)
+ assert a.size() == b.size()
+
+ return Matrix(a.size(),
+ a.data * b.data,
+ a.row_indices,
+ a.column_indices,
+ a.offsets,
+ a.column_indices_t,
+ a.offsets_t,
+ a.block_offsets_t)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops_test.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/eltwise_ops_test.py
@@ -0,0 +1,86 @@
+import unittest
+import itertools
+import torch
+from absl.testing import parameterized
+
+import stk
+from stk.ops.linear_ops_test import allclose, _dense_and_sparse
+
+_MATRIX_SIZES = (
+ (128, 128, 0.0),
+ (256, 256, 0.5),
+ (2048, 1024, 0.8),
+ (512, 128, 0.0),
+ (128, 512, 0.0),
+ (1024, 512, 0.0),
+ (1024, 512, 0.5),
+ (1024, 512, 0.75),
+ (512, 1024, 0.0),
+ (512, 1024, 0.5),
+ (512, 1024, 0.75),
+ (1024, 1024, 0.0),
+ (1024, 1024, 0.5),
+ (1024, 1024, 0.75),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
+ testcases = [(*size, 128, dtype) for
+ (size, dtype) in testcases]
+ return testcases
+
+_ELTWISE_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse_like(x, std=0.1):
+ dense_data = torch.randn_like(x.data, device=x.device) * std
+ sparse = stk.Matrix(x.size(),
+ dense_data,
+ x.row_indices,
+ x.column_indices,
+ x.offsets)
+ dense = stk.ops.to_dense(sparse)
+
+ return (dense.requires_grad_(True),
+ sparse.requires_grad_(True))
+
+@parameterized.parameters(_ELTWISE_OP_TESTS)
+class EltwiseOpsTest(parameterized.TestCase):
+
+ def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
+
+ a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+ b_dense, b = _dense_and_sparse_like(a)
+
+ out = stk.ops.mul(a, b)
+ expected_out = torch.mul(a_dense, b_dense)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size(), out.size())
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = a_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = b_dense.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size(), grad.size())
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/linear_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/linear_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/linear_ops.py
@@ -0,0 +1,59 @@
+import torch
+
+from ..backend import sputnik
+from ..matrix import Matrix
+
+
+def dsd(a, b):
+ assert isinstance(a, Matrix)
+ assert isinstance(b, torch.Tensor)
+ return sputnik.dsd(
+ a.size(),
+ a.data, a.offsets,
+ a.row_indices,
+ a.column_indices,
+ a.offsets_t,
+ a.column_indices_t,
+ a.block_offsets_t,
+ not a.is_contiguous(),
+ b)
+
+
+def dds(a, b):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, Matrix)
+ return sputnik.dds(
+ a,
+ b.size(),
+ b.data, b.offsets,
+ b.row_indices,
+ b.column_indices,
+ b.offsets_t,
+ b.column_indices_t,
+ b.block_offsets_t,
+ not b.is_contiguous())
+
+
+def sdd(a, b, topo):
+ assert isinstance(a, torch.Tensor)
+ assert isinstance(b, torch.Tensor)
+ assert isinstance(topo, Matrix)
+ assert topo.is_contiguous()
+ out = sputnik.sdd(
+ a, b,
+ topo.size(),
+ topo.data,
+ topo.offsets,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets_t,
+ topo.column_indices_t,
+ topo.block_offsets_t)
+ return Matrix(topo.size(),
+ out,
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t)
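+
+
+# Naming convention: the three letters give the (output, lhs, rhs) formats,
+# with "d" dense and "s" sparse. A hypothetical usage sketch (assuming the
+# package is importable as `stk` and CUDA float16 operands):
+#
+#   a = torch.randn(512, 512, dtype=torch.float16, device="cuda")
+#   b = torch.randn(512, 512, dtype=torch.float16, device="cuda")
+#   topo = stk.random.mask(512, 512, sparsity=0.75, blocking=128).cuda()
+#   c = stk.ops.sdd(a, b, topo)   # sparse output = dense @ dense, masked by topo
+#   d = stk.ops.dsd(c, b)         # dense output  = sparse @ dense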
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/linear_ops_test.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/linear_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/linear_ops_test.py
@@ -0,0 +1,216 @@
+import unittest
+import itertools
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+import stk
+
+
+def allclose(x, y, pct=0.25):
+ mask = torch.isclose(x, y, rtol=5e-2)
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
+ if pct_diff > pct:
+ print("{:.2f}% of values not close.".format(pct_diff))
+ return False
+ return True
+
+
+# An assortment of problems designed to make sure
+# the bindings are operating correctly.
+_MATRIX_SIZES = (
+ (128, 128, 128, 0.0),
+ (256, 256, 256, 0.5),
+ (2048, 1024, 512, 0.8),
+ (512, 128, 128, 0.0),
+ (128, 128, 512, 0.0),
+ (1024, 512, 512, 0.0),
+ (1024, 512, 512, 0.5),
+ (1024, 512, 512, 0.75),
+ (512, 512, 1024, 0.0),
+ (512, 512, 1024, 0.5),
+ (512, 512, 1024, 0.75),
+ (1024, 1024, 1024, 0.0),
+ (1024, 1024, 1024, 0.5),
+ (1024, 1024, 1024, 0.75),
+)
+
+_TRANSPOSE = (
+ (False, False),
+ (False, True),
+ (True, False),
+ (True, True),
+)
+
+_DTYPE = (
+ torch.float16, torch.bfloat16
+)
+
+def _generate_testcases():
+ testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
+ testcases = [(*size, *trans, 128, dtype) for
+ (size, trans, dtype) in testcases]
+ return testcases
+
+_LINEAR_OP_TESTS = _generate_testcases()
+
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ dense = (torch.randn(rows, cols) * std * mask).type(dtype)
+ sparse = stk.ops.to_sparse(dense, blocking)
+ cuda_device = torch.device("cuda")
+ return (dense.to(cuda_device).requires_grad_(True),
+ sparse.to(cuda_device).requires_grad_(True))
+
+
+def _dense(rows, cols, dtype, std=0.1):
+ cuda_device = torch.device("cuda")
+ out = (torch.randn(rows, cols) * std).type(dtype)
+ return out.to(cuda_device).requires_grad_(True)
+
+
+def _dense_2x(rows, cols, dtype):
+ a = _dense(rows, cols, dtype)
+ return a, a.detach().requires_grad_(True)
+
+
+def _with_transpose(op, a, b, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b)
+
+
+def _mmm(a, b, topo):
+ mask = stk.ops.to_dense(stk.ops.ones_like(topo))
+ return torch.mm(a, b) * mask
+
+
+def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
+ a = a.t() if trans_a else a
+ b = b.t() if trans_b else b
+ return op(a, b, topo)
+
+
+def _mask(x, mask):
+ mask = stk.ops.to_dense(stk.ops.ones_like(mask))
+ return x * mask
+
+
+@parameterized.parameters(*_LINEAR_OP_TESTS)
+class LinearOpsTest(parameterized.TestCase):
+
+ def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = stk.ops.to_dense(a.grad)
+ expected_grad = _mask(a_dense.grad, a.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
+ expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ out.sum().backward()
+
+ # Validate the results.
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = stk.ops.to_dense(b.grad)
+ expected_grad = _mask(b_dense.grad, b.grad)
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
+ # Construct the operands.
+ a_shape = (k, m) if trans_a else (m, k)
+ a, acp = _dense_2x(*a_shape, dtype)
+ b_shape = (n, k) if trans_b else (k, n)
+ b, bcp = _dense_2x(*b_shape, dtype)
+ _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
+
+ # Execute the matmul.
+ out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
+ expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
+
+ # Compute the gradients w.r.t. the inputs.
+ expected_out.sum().backward()
+ stk.ops.sum(out).backward()
+
+ # Validate the results.
+ out = stk.ops.to_dense(out)
+ self.assertEqual(out.dim(), 2)
+ self.assertEqual(expected_out.size()[0], out.size()[0])
+ self.assertEqual(expected_out.size()[1], out.size()[1])
+ self.assertTrue(allclose(out, expected_out))
+
+ # LHS gradient.
+ grad = a.grad
+ expected_grad = acp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+ # RHS gradient.
+ grad = b.grad
+ expected_grad = bcp.grad
+ self.assertEqual(grad.dim(), 2)
+ self.assertEqual(expected_grad.size()[0], grad.size()[0])
+ self.assertEqual(expected_grad.size()[1], grad.size()[1])
+ self.assertTrue(allclose(grad, expected_grad))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops.py
@@ -0,0 +1,98 @@
+from ..backend import sputnik
+from ..matrix import Matrix
+import torch
+import numpy as np
+
+
+@torch.no_grad()
+def row_indices(shape, data, offsets, column_indices):
+ return sputnik.row_indices(shape, data, offsets, column_indices)
+
+
+# TODO(tgale): Replace this helper with a custom kernel. This operation
+# is much simpler to do than how it's currently implemented.
+@torch.no_grad()
+def _expand_for_blocking(idxs, blocking):
+ # Duplicate for block column dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
+
+ # Update the column indices.
+ idxs[:, :, 1] *= blocking
+ idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
+
+ # Duplicate for block row dimension.
+ idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
+ idxs = idxs.repeat(1, blocking, 1, 1)
+
+ # Update the row indices.
+ idxs[:, :, :, 0] *= blocking
+ idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
+ idxs = torch.reshape(idxs, [-1, 2])
+ return idxs
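+
+# Worked example (illustrative): with blocking=2, the single block coordinate
+# [[1, 0]] expands to the four element coordinates covered by that block,
+# [[2, 0], [2, 1], [3, 0], [3, 1]], i.e. rows 2..3 and columns 0..1 of the
+# dense matrix.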
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_dense(x):
+ assert isinstance(x, Matrix)
+
+ shape = (np.prod(x.shape[:-1]), x.shape[-1])
+ row_idxs = x.row_indices.type(torch.int32)
+ col_idxs = x.column_indices.type(torch.int32)
+ indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
+ indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)
+
+ out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
+ out.scatter_(0, indices, x.data.flatten())
+ return out.reshape(x.size())
+
+
+@torch.no_grad()
+def _mask(x, blocking=1):
+ assert x.dim() == 2
+ assert x.size()[0] % blocking == 0
+ assert x.size()[1] % blocking == 0
+ block_rows = x.size()[0] // blocking
+ block_cols = x.size()[1] // blocking
+ x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
+ x = torch.sum(torch.abs(x), dim=(1, 3))
+ return x != 0
+
+
+# TODO(tgale): Add input type checking.
+@torch.no_grad()
+def to_sparse(x, blocking=1):
+ m = _mask(x, blocking)
+
+ # TODO(tgale): Set to appropriate type for input matrix.
+ row_nnzs = torch.sum(m, dim=1).type(torch.int32)
+ zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
+ offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
+ offsets = offsets.type(torch.int32)
+
+ indices = torch.nonzero(m).type(torch.int16)
+ row_indices = indices[:, 0]
+ column_indices = indices[:, 1]
+
+ # Nonzero indices in the dense matrix.
+ nonzero_indices = torch.nonzero(m)
+ nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
+ nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]
+
+ # Gather the data and construct the sparse matrix.
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
+ data = torch.reshape(data, [-1, blocking, blocking])
+ return Matrix(x.size(), data, row_indices, column_indices, offsets)
+
+
+@torch.no_grad()
+def ones_like(x):
+ return Matrix(x.size(),
+ torch.ones_like(x.data),
+ x.row_indices,
+ x.column_indices, x.offsets)
+
+
+def sum(x):
+ assert isinstance(x, Matrix)
+ return x.data.sum()
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops_test.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/ops/matrix_ops_test.py
@@ -0,0 +1,62 @@
+import unittest
+
+from absl.testing import parameterized
+import stk
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class MatrixOpsTest(parameterized.TestCase):
+
+ def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
+ mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
+ x = (torch.randn(rows, cols) * mask).type(torch.float16)
+
+ # Convert the matrix to sparse format.
+ sparse_x = stk.ops.to_sparse(x, blocking)
+
+ # Validate the matrix.
+ sparse_x.validate()
+
+ # Validate the shape.
+ self.assertEqual(sparse_x.dim(), 2)
+ self.assertEqual(sparse_x.size()[0], rows)
+ self.assertEqual(sparse_x.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(sparse_x.nnz, nnz)
+
+ # Convert back to dense format.
+ dense_x = stk.ops.to_dense(sparse_x)
+
+ # Validate the shape.
+ self.assertEqual(dense_x.dim(), 2)
+ self.assertEqual(dense_x.size()[0], rows)
+ self.assertEqual(dense_x.size()[1], cols)
+
+ # Validate the sparsity
+ self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
+
+ # Validate the output.
+ self.assertTrue(torch.all(torch.eq(x, dense_x)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/random/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/random/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/random/__init__.py
@@ -0,0 +1,2 @@
+# from stk.random.random_ops import dense_mask, mask, randn
+from .random_ops import dense_mask, mask, randn
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/random/random_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/random/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/random/random_ops.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+from ..ops import matrix_ops
+
+
+@torch.no_grad()
+def dense_mask(rows, cols, sparsity, blocking=1):
+ assert sparsity >= 0.0 and sparsity <= 1.0
+ assert rows % blocking == 0 and cols % blocking == 0
+
+ block_rows, block_cols = (rows // blocking, cols // blocking)
+ nnz = round(block_rows * block_cols * (1 - sparsity))
+
+ out = np.ones(block_rows * block_cols)
+ mask = np.random.choice(out.size, out.size - nnz, replace=False)
+ out[mask] = 0.0
+
+ out = np.tile(
+ np.reshape(out, [block_rows, 1, block_cols, 1]),
+ (1, blocking, 1, blocking))
+ out = np.reshape(out, [rows, cols])
+ return torch.from_numpy(out.astype(np.float32))
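+
+# Example (illustrative): dense_mask(256, 256, sparsity=0.75, blocking=128)
+# yields a 2x2 block grid and keeps round(4 * 0.25) = 1 block, i.e. exactly
+# 128 * 128 nonzero entries in the returned float32 mask.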
+
+
+@torch.no_grad()
+def mask(m, n, sparsity, blocking=1):
+ out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
+ return matrix_ops.to_sparse(out, blocking=blocking)
+
+
+@torch.no_grad()
+def randn(shape, sparsity, blocking=1):
+ shape_2d = (np.prod(shape[:-1]), shape[-1])
+ out = mask(*shape_2d, sparsity, blocking)
+ out.data.copy_(torch.randn(*out.data.shape))
+ return out.view(*shape)
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/stk/random/random_ops_test.py b/build/torch29-cxx11-cu130-x86_64-linux/stk/random/random_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/stk/random/random_ops_test.py
@@ -0,0 +1,73 @@
+import unittest
+
+from absl.testing import parameterized
+from . import random_ops as random
+import torch
+
+
+@parameterized.parameters(
+ (8, 16, 0.0, 1),
+ (8, 16, 0.5, 1),
+ (8, 16, .95, 1),
+ (16, 8, 0.0, 1),
+ (16, 8, 0.5, 1),
+ (16, 8, .95, 1),
+ (8, 16, 0.0, 8),
+ (8, 16, 0.5, 8),
+ (8, 16, 1.0, 8),
+ (16, 8, 0.0, 8),
+ (16, 8, 0.5, 8),
+ (16, 8, 1.0, 8),
+ (128, 256, 0.5, 16),
+ (256, 128, 0.75, 32),
+ (512, 512, .875, 128))
+class RandomOpsTest(parameterized.TestCase):
+
+ def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
+ mask = random.dense_mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(
+ torch.count_nonzero(mask).item(),
+ nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask, 0),
+ torch.eq(mask, 1))))
+
+ def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
+ mask = random.mask(
+ rows, cols, sparsity, blocking)
+
+ # Validate the matrix.
+ mask.validate()
+
+ # Validate the shape.
+ self.assertEqual(mask.dim(), 2)
+ self.assertEqual(mask.size()[0], rows)
+ self.assertEqual(mask.size()[1], cols)
+
+ # Validate the sparsity.
+ numblocks = rows // blocking * cols // blocking
+ nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
+ self.assertEqual(mask.nnz, nnz)
+
+ # Check values are zero or one.
+ self.assertTrue(
+ torch.all(torch.logical_or(
+ torch.eq(mask.data, 0),
+ torch.eq(mask.data, 1))))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/xpu_fused_moe.py b/build/torch29-cxx11-cu130-x86_64-linux/xpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0440b484e6f4073a384dcf3bc562174d2ef71d25
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/xpu_fused_moe.py
@@ -0,0 +1,672 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks XPU Fused MoE Implementation
+import os
+import torch
+
+from ._ops import ops, add_op_namespace_prefix
+
+from torch.library import register_fake
+
+
+def resolve_dtensor(weight: torch.Tensor):
+ """Convert DTensor to local tensor for use with custom ops."""
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+# Register fake/meta kernels for torch.compile compatibility
+def _register_xpu_fake_kernels():
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
+
+ def _register_if_available(op_name, fn):
+ if hasattr(ops, op_name):
+ register_fake(add_op_namespace_prefix(op_name))(fn)
+
+ _register_if_available(
+ "cutlass_grouped_gemm_interface",
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
+ )
+
+ _register_if_available(
+ "fused_moe_prologue",
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
+ )
+
+ _register_if_available(
+ "moe_gather",
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
+ )
+
+ _register_if_available(
+ "silu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "mul_and_silu",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_tanh_and_mul",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_fast",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_new",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "gelu_quick",
+ lambda out, input: None,
+ )
+ _register_if_available(
+ "swigluoai_and_mul",
+ lambda out, input, alpha=1.702, limit=7.0: None,
+ )
+
+
+# Register fake kernels on module load
+_register_xpu_fake_kernels()
+
+
+# Default grouped GEMM wrapper (see cutlass_grouped_gemm_xe2 below for the Xe2 variant).
+def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
+ k, num_experts):
+ # expert_token_count_ = torch.tensor(expert_token_count,
+ # dtype=torch.int64,
+ # device=input_A.device)
+ # if bias is not None:
+ # bias = bias.repeat_interleave(expert_token_count_, dim=0).float()
+
+ def exclusive_prefix_sum(arr):
+ prefix = [0]
+ for i, x in enumerate(arr):
+ prefix.append(prefix[-1] + x)
+ return prefix
+
+ expert_offset = torch.tensor(exclusive_prefix_sum(expert_token_count),
+ dtype=torch.int64,
+ device="xpu")
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=False,
+ is_B_mxfp4=False)
+
+
+def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
+ num_rows_per_expert, n, k, num_experts, is_B_int4,
+ is_B_mxfp4):
+ expert_first_token_offset = torch.cat([
+ torch.tensor([0],
+ dtype=num_rows_per_expert.dtype,
+ device=num_rows_per_expert.device),
+ torch.cumsum(num_rows_per_expert, dim=0)
+ ]).to(torch.int64)
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=scales,
+ ptr_bias=bias,
+ ptr_D=output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=n,
+ K=k,
+ num_experts=num_experts,
+ is_B_int4=is_B_int4,
+ is_B_mxfp4=is_B_mxfp4)
+
+
+def ceilDiv(a, b):
+ return (a + b - 1) // b
+
+
+def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
+ for num_tokens_per_block in [32, 64, 128, 256, 512, 1024]:
+ num_blocks_per_seq = ceilDiv(num_tokens, num_tokens_per_block)
+ if num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block:
+ return num_tokens_per_block
+ return 1024
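+
+
+# Example (illustrative): with num_tokens=1000 and num_experts_per_node=8 the
+# search above returns 128, since ceilDiv(1000, 128) * 8 = 64 <= 128, while the
+# smaller candidates fail (e.g. ceilDiv(1000, 64) * 8 = 128 > 64).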
+
+
+def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
+
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
+ constant folding when shape divisibility is not proven.
+ """
+ if byte_tensor.dtype != torch.uint8:
+ raise ValueError("byte_tensor must be uint8")
+ itemsize = torch.empty((), dtype=dtype).element_size()
+ numel = byte_tensor.numel() // itemsize
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
+ return out
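+
+
+# Example (illustrative): eight consecutive uint8 bytes reinterpret as one
+# int64 value, so a workspace slice of N bytes yields N // 8 int64 entries:
+#
+#   raw = torch.zeros(16, dtype=torch.uint8)
+#   vals = _bytes_to_typed_tensor(raw, torch.int64)  # shape (2,), all zeros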
+
+
+def implement_zp(qweight):
+ # change u4 to s4 to avoid zero point in gemm kernel
+ # only support default zero point now
+ assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
+
+ high_u4 = (qweight >> 4) & 0x0F
+ low_u4 = qweight & 0x0F
+
+ high_s8 = high_u4.to(torch.int8)
+ low_s8 = low_u4.to(torch.int8)
+
+ high_s8 = high_s8 - 8
+ low_s8 = low_s8 - 8
+
+ def pack_compact(a, b):
+
+ def process_number(x):
+ sign = (x < 0).to(torch.uint8)
+ abs_low3 = (x.view(torch.uint8) & 0x7).to(torch.uint8)
+ return (sign << 3) | abs_low3
+
+ packed_a = process_number(a)
+ packed_b = process_number(b)
+
+ return (packed_a << 4) | packed_b
+
+ result = pack_compact(high_s8, low_s8)
+
+ return result
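+
+
+# Worked example (illustrative): the packing above reproduces 4-bit two's
+# complement with the default zero point of 8 folded in. A stored uint4
+# nibble of 0x0 maps to -8 (sign=1, low bits 000 -> 0b1000), 0x7 maps to -1
+# (0b1111), and 0xF maps to +7 (0b0111), so the GEMM kernel can read the
+# repacked weights directly as signed int4.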
+
+
+def xpu_fused_moe(hidden_states,
+ w13,
+ w13_scales,
+ w13_bias,
+ w2,
+ w2_scales,
+ w2_bias,
+ topk_weights,
+ topk_ids,
+ n_experts_per_token,
+ activation,
+ num_experts,
+ ep_rank=0,
+ ep_size=1,
+ is_fp8=False,
+ is_int4=False,
+ is_mxfp4=False):
+ '''
+ hidden_states: [num_rows, hidden_size]
+ w13: [num_experts, 2*inter_size, hidden_size]
+ w13_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
+ w13_bias: [num_experts, 2*inter_size] or None
+ w2: [num_experts, hidden_size, inter_size]
+ w2_scales:
+ None for bf16/fp16
+ or [num_experts] for fp8
+ or [num_experts, hidden_size, inter_size // group_size] for 4bits
+ w2_bias: [num_experts, hidden_size] or None
+ topk_weights: [num_rows, topk]
+ topk_ids: [num_rows, topk]
+ n_experts_per_token: int
+ activation: str
+ num_experts: int
+ is_int4: bool
+ is_mxfp4: bool
+ '''
+
+ # Resolve DTensors to local tensors before passing to custom ops
+ hidden_states = resolve_dtensor(hidden_states)
+ w13 = resolve_dtensor(w13)
+ w2 = resolve_dtensor(w2)
+ if w13_scales is not None:
+ w13_scales = resolve_dtensor(w13_scales)
+ if w13_bias is not None:
+ w13_bias = resolve_dtensor(w13_bias)
+ if w2_scales is not None:
+ w2_scales = resolve_dtensor(w2_scales)
+ if w2_bias is not None:
+ w2_bias = resolve_dtensor(w2_bias)
+ topk_weights = resolve_dtensor(topk_weights)
+ topk_ids = resolve_dtensor(topk_ids)
+
+ output = torch.empty_like(hidden_states)
+ num_rows, hidden_size = list(hidden_states.shape)
+
+ dim_last = w13.shape[-1]
+ dim_second_last = w13.shape[-2]
+
+ # w13 is combined gate+up weights, so one dimension is 2*inter_size
+ # Determine which dimension is hidden_size and which is 2*inter_size
+ if dim_second_last == hidden_size:
+ # w13 is [E, hidden_size, 2*inter_size] - standard layout
+ inter_size = dim_last // 2
+ needs_transpose = False
+ else:
+ # w13 is [E, 2*inter_size, hidden_size] - needs transpose
+ inter_size = dim_second_last // 2
+ needs_transpose = True
+
+ assert w13.is_contiguous() and w2.is_contiguous()
+
+ # 4bits support [E, N, K]
+ # other types [E, K, N]
+ if not is_int4 and not is_mxfp4:
+ if not hasattr(w13, 'xpu_fused_moe'):
+ if needs_transpose:
+ w13.data = w13.transpose(-1, -2).contiguous()
+ w2.data = w2.transpose(-1, -2).contiguous()
+ w13.xpu_fused_moe = True
+ w13.inter_size = inter_size
+ else:
+ inter_size = w13.inter_size
+
+ if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
+ w13_tmp = torch.empty_like(w13)
+ w2_tmp = torch.empty_like(w2)
+ for i in range(num_experts):
+ w13_tmp[i] = implement_zp(w13[i])
+ w2_tmp[i] = implement_zp(w2[i])
+ w13_tmp = w13_tmp.contiguous()
+ w2_tmp = w2_tmp.contiguous()
+ w13.data = w13_tmp
+ w2.data = w2_tmp
+ w13.xpu_fused_moe = True
+
+    # TODO: this will all be integrated into the C++ func; temporarily exposed before GEMM fusion.
+ num_experts_per_node = num_experts
+ experts_per_token = n_experts_per_token
+ num_moe_inputs = n_experts_per_token * num_rows
+ permuted_elems = num_moe_inputs * hidden_size
+ # interbuf_elems = num_moe_inputs * inter_size
+ permuted_row_to_unpermuted_row_size = num_moe_inputs * 4
+ permuted_token_selected_experts_size = num_moe_inputs * 4
+ src_to_dest_map_size = experts_per_token * num_rows * 4
+ expert_first_token_offset_size = (num_experts_per_node + 1) * 8
+ num_tokens_per_block = compute_num_tokens_per_block(
+ num_rows, num_experts_per_node)
+ num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)
+ blocked_expert_counts_size = num_experts_per_node * num_blocks_per_seq * 4
+ blocked_expert_counts_cumsum_size = blocked_expert_counts_size
+ blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
+ permuted_data_size = permuted_elems * hidden_states.element_size()
+ permuted_token_final_scales_size = num_moe_inputs * 4
+
+ ws_map = {}
+ map_offset = 0
+
+ def config_ws(name, size):
+ nonlocal map_offset
+ if size % 256 != 0:
+ size += 256 - size % 256
+ ws_map[name] = (size, map_offset)
+ map_offset += size
+
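+    # Example (illustrative): a request of 100 bytes is rounded up to 256, so
+    # every named slice below starts on a 256-byte boundary inside `workspace`.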
+ config_ws("permuted_row_to_unpermuted_row",
+ permuted_row_to_unpermuted_row_size)
+ config_ws("permuted_token_selected_experts",
+ permuted_token_selected_experts_size)
+ config_ws("unpermuted_row_to_permuted_row", src_to_dest_map_size)
+ config_ws("blocked_expert_counts", blocked_expert_counts_size)
+ config_ws("blocked_expert_counts_cumsum",
+ blocked_expert_counts_cumsum_size)
+ config_ws("blocked_row_to_unpermuted_row",
+ blocked_row_to_unpermuted_row_size)
+ config_ws("expert_first_token_offset", expert_first_token_offset_size)
+ config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
+ config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
+
+ workspace = torch.zeros(map_offset,
+ dtype=torch.uint8,
+ device=hidden_states.device)
+ if topk_ids.dtype == torch.int32:
+ topk_ids = topk_ids.to(torch.int64)
+ ops.fused_moe_prologue(
+ input=hidden_states,
+ token_selected_experts=topk_ids,
+ token_final_scales=topk_weights,
+ workspace=workspace,
+ hidden_size=hidden_size,
+ inter_size=inter_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ num_experts_on_rank=num_experts_per_node)
+
+ expert_first_token_offset_bytes = workspace[
+ ws_map["expert_first_token_offset"][1]:
+ ws_map["expert_first_token_offset"][1] +
+ expert_first_token_offset_size]
+ unpermuted_row_to_permuted_row_bytes = workspace[
+ ws_map["unpermuted_row_to_permuted_row"][1]:
+ ws_map["unpermuted_row_to_permuted_row"][1] +
+ src_to_dest_map_size]
+ permuted_row_to_unpermuted_row_bytes = workspace[
+ ws_map["permuted_row_to_unpermuted_row"][1]:
+ ws_map["permuted_row_to_unpermuted_row"][1] +
+ permuted_row_to_unpermuted_row_size]
+
+ if torch.compiler.is_compiling():
+ expert_first_token_offset = _bytes_to_typed_tensor(
+ expert_first_token_offset_bytes, torch.int64
+ )
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
+ unpermuted_row_to_permuted_row_bytes, torch.int32
+ )
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+ permuted_row_to_unpermuted_row_bytes, torch.int32
+ )
+ else:
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
+ gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
+ ws_map["overlapped_gemm1_gemm2_inputs"][1] +
+ permuted_data_size].view(hidden_states.dtype).view(
+ num_moe_inputs, hidden_size)
+ # permuted_token_final_scales = workspace[
+ # ws_map["permuted_token_final_scales"][1]:
+ # ws_map["permuted_token_final_scales"][1] +
+ # permuted_token_final_scales_size].view(torch.float)
+ gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ ########### gemm1 ##################
+ input_B = w13
+
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=gemm1_input,
+ ptr_B=input_B,
+ ptr_scales=w13_scales,
+ ptr_bias=w13_bias,
+ ptr_D=gemm1_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=2 * inter_size,
+ K=hidden_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ # act
+ act_output = torch.empty((num_moe_inputs, inter_size),
+ dtype=gemm1_output.dtype,
+ device=gemm1_output.device)
+ if activation == "silu":
+ ops.silu_and_mul(act_output, gemm1_output)
+ elif activation == "gelu":
+ ops.gelu_and_mul(act_output, gemm1_output)
+ elif activation == "swigluoai":
+ ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+ else:
+ raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
+
+ ########### gemm2 ##################
+ input_A = act_output.contiguous()
+ input_B = w2
+ gemm2_output = torch.empty((num_moe_inputs, hidden_size),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ if not is_fp8 and not is_int4 and not is_mxfp4:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=None,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+ else:
+ ops.cutlass_grouped_gemm_interface(
+ ptr_A=input_A,
+ ptr_B=input_B,
+ ptr_scales=w2_scales,
+ ptr_bias=w2_bias,
+ ptr_D=gemm2_output,
+ expert_first_token_offset=expert_first_token_offset,
+ N=hidden_size,
+ K=inter_size,
+ num_experts=num_experts_per_node,
+ is_B_int4=is_int4,
+ is_B_mxfp4=is_mxfp4)
+
+ ops.moe_gather(output, gemm2_output, topk_weights,
+ permuted_row_to_unpermuted_row,
+ unpermuted_row_to_permuted_row,
+ expert_first_token_offset,
+ num_experts_per_node)
+ return output
+
+
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ """Apply jitter to the input tensor for regularization."""
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ """Compute the top-k scores from the logits."""
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
+
+def route_tokens_xpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ training: bool = False,
+) -> tuple:
+ """Route tokens to experts and compute expert weights and indices (XPU version)."""
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
+
+def _get_device_mesh(model):
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
+ try:
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Get EP (Expert Parallelism) parameters
+ ep_size = 1
+ ep_rank = 0
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = _get_device_mesh(self)
+ if device_mesh is not None:
+ expert_parallel_group = device_mesh.get_group()
+ if expert_parallel_group is not None:
+ import torch.distributed as dist
+ if dist.is_initialized():
+ ep_size = dist.get_world_size(expert_parallel_group)
+ ep_rank = dist.get_rank(expert_parallel_group)
+
+ # Number of experts on this rank
+ num_experts_on_rank = moe_num_experts // ep_size
+
+ # Detect activation type - check for GptOss-style swigluoai activation
+ # GptOssExperts has alpha and limit attributes for swigluoai
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+
+ # Get weight tensors - support different naming conventions
+ if hasattr(self.experts, "gate_up_proj"):
+ w13 = self.experts.gate_up_proj
+ # NOTE: swigluoai_and_mul kernel expects interleaved layout [g0,u0,g1,u1,...]
+ # which matches GptOss's gate_up_proj layout, so no conversion needed.
+
+ elif hasattr(self.experts, "w1"):
+ # Combine w1 and w3 if stored separately
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w13 = torch.cat([w1, w3], dim=-2)
+ else:
+ w13 = w1
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w13_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Get quantization info
+ is_fp8 = getattr(self.experts, "is_fp8", False)
+ is_int4 = getattr(self.experts, "is_int4", False)
+ is_mxfp4 = getattr(self.experts, "is_mxfp4", False)
+
+ w13_scales = getattr(self.experts, "gate_up_proj_scales", None)
+ w2_scales = getattr(self.experts, "down_proj_scales", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_xpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ self.training,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call XPU fused MoE kernel
+ output = xpu_fused_moe(
+ hidden_states=x_flat,
+ w13=w13,
+ w13_scales=w13_scales,
+ w13_bias=w13_bias,
+ w2=w2,
+ w2_scales=w2_scales,
+ w2_bias=w2_bias,
+ topk_weights=expert_weights.float(),
+ topk_ids=expert_indices,
+ n_experts_per_token=moe_top_k,
+ activation=activation,
+ num_experts=num_experts_on_rank,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ is_fp8=is_fp8,
+ is_int4=is_int4,
+ is_mxfp4=is_mxfp4,
+ )
+
+ # All-reduce across EP group to combine partial expert outputs
+ if ep_size > 1 and expert_parallel_group is not None:
+ import torch.distributed as dist
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "xpu_fused_moe",
+ "cutlass_grouped_gemm",
+ "cutlass_grouped_gemm_xe2",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/__init__.py b/build/torch29-cxx11-xpu20252-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38075732c6d8fa0e1e6ef493145e1aca3851ae6b
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/__init__.py
@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code).
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute exclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.exclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+ """
+ Compute inclusive cumulative sum along the specified dimension.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ result = ops.inclusive_cumsum(x, dim)
+ out.copy_(result)
+ return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """
+ Compute histogram of input tensor values.
+
+ Args:
+ x: Input tensor
+ num_bins: Number of histogram bins
+
+ Returns:
+ Histogram tensor with counts for each bin
+ """
+ return ops.histogram(x, num_bins)
+
+
+def indices(
+ padded_bins: torch.Tensor,
+ block_size: int,
+ output_block_rows: int,
+ output_block_columns: int,
+) -> torch.Tensor:
+ """
+ Construct indices from padded bins for sparse operations.
+
+ Args:
+ padded_bins: Tensor containing bin boundaries
+ block_size: Size of each block
+ output_block_rows: Number of rows in output blocks
+ output_block_columns: Number of columns in output blocks
+
+ Returns:
+ Tensor containing constructed indices
+ """
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Forward pass of replicate operation - replicate values according to bin sizes.
+
+ Args:
+ x: Input tensor with values to replicate
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Backward pass of replicate operation - reduce gradients back to bins.
+
+ Args:
+ grad: Gradient tensor to reduce
+ bins: Tensor containing bin sizes
+ out: Output tensor (modified in-place)
+
+ Returns:
+ The output tensor
+ """
+ return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+ """
+ Radix sort with index tracking.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+ x_out: Output tensor for sorted values
+ iota_out: Output tensor for sorted indices
+
+ Returns:
+ The sorted values tensor
+ """
+ return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+ """
+ Compute cumulative sum with automatic output allocation.
+
+ Args:
+ x: Input tensor
+ dim: Dimension along which to compute cumsum (default: last dimension)
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+ Returns:
+ New tensor containing the cumulative sum
+ """
+ out = torch.empty_like(x)
+ if exclusive:
+ return exclusive_cumsum(x, dim, out)
+ else:
+ return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sort tensor and return both sorted values and indices.
+
+ Args:
+ x: Input tensor to sort
+ end_bit: Number of bits to consider in sorting
+
+ Returns:
+ Tuple of (sorted_values, sorted_indices)
+ """
+ x_out = torch.empty_like(x)
+ iota_out = torch.empty_like(x)
+ sort(x, end_bit, x_out, iota_out)
+ return x_out, iota_out
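+
+
+# Usage sketch (illustrative; argument choices are assumptions, not part of
+# the exported API contract):
+#
+#   expert_ids = torch.randint(0, 8, (1024,), dtype=torch.int32, device="cuda")
+#   counts = histogram(expert_ids, num_bins=8)        # tokens per expert
+#   starts = cumsum(counts, exclusive=True)           # start offset per expert
+#   sorted_ids, sorted_pos = argsort(expert_ids, end_bit=3)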
+
+
+# Export public API
+__all__ = [
+    # Direct kernel exports
+ "exclusive_cumsum",
+ "inclusive_cumsum",
+ "histogram",
+ "indices",
+ "replicate_forward",
+ "replicate_backward",
+ "sort",
+ "cumsum",
+ "argsort",
+ # Original exports
+ "Arguments",
+ "ParallelDroplessMLP",
+ "dMoE",
+ "SparseGLU",
+ "MLP",
+ "SparseMLP",
+ "MoE",
+ "ParallelMLP",
+ "get_load_balancing_loss",
+]
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/__init__.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+ 'MoE',
+ # 'dMoE',
+]
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/activation_fn.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/activation_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/activation_fn.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+ x: Matrix,
+ function: Callable,
+ return_grad_fn: bool = False,
+ **kwargs,
+) -> Union[tuple[Matrix, Any], Matrix]:
+ assert isinstance(x, Matrix)
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+ if return_grad_fn:
+ x.data.requires_grad = True
+ out = function(x.data, **kwargs)
+ y = Matrix(
+ x.size(),
+ out,
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ if return_grad_fn:
+ return y, out.backward
+ return y
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/all_to_all.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/all_to_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/all_to_all.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = dist.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ dist.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
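+# Convenience wrapper. A usage sketch (assuming torch.distributed is already
+# initialized and `group` spans the expert-parallel ranks):
+#
+#   out, handle = all_to_all(x, recv_counts, send_counts, group, async_op=True)
+#   handle.wait()  # block until the permuted tokens have arrived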
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/arguments.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/arguments.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+ # Model arguments.
+ hidden_size: int = 1024
+ ffn_hidden_size: int = 4096
+ num_layers: int = 1
+ bias: bool = True
+ return_bias: bool = True
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+ # MoE arguments.
+ moe_num_experts: int = 1
+ moe_top_k: int = 1
+ moe_capacity_factor: int = 1
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
+ moe_loss_weight: float = 0.1
+ moe_jitter_eps: Optional[float] = None
+ moe_lbl_in_fp32: bool = False
+
+ # Parallelism arguments.
+ moe_expert_model_parallelism: bool = False
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
+ pipeline_model_parallel_size: int = 1
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+ # Compute arguments.
+ memory_optimized_mlp: bool = False
+ mlp_type: str = 'mlp'
+ mlp_impl: str = 'sparse'
+
+ # Initialization arguments.
+ fp16: bool = True
+ bf16: bool = False
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+ output_layer_init_method: InitFn = init_method
+
+ # Benchmarking arguments.
+ uniform_expert_assignment: bool = False
+
+ # shared expert arguments
+ shared_expert: bool = False # enable using shared expert
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
+ shared_expert_hidden_size: Optional[
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+ # Router Z-loss arguments
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
+ moe_zloss_in_fp32: bool = False
+
+ def __post_init__(self):
+ # Sparse MLP is not supported with triton >=3.2.0
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ if self.__getattribute__('mlp_impl') == 'sparse':
+            try:
+                import triton
+                from packaging import version
+                if version.parse(triton.__version__) >= version.parse('3.2.0'):
+ raise ValueError(
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+ )
+ except ImportError:
+ raise ImportError('Triton is required for sparse MLP implementation')
+
+ if self.__getattribute__('mlp_impl') == 'grouped':
+ grouped_gemm.assert_grouped_gemm_is_available()
+
+ if self.shared_expert_hidden_size is None:
+ self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+ args = Arguments()
+ for field in dataclasses.fields(args):
+ if hasattr(megatron_args, field.name):
+ setattr(args, field.name, getattr(megatron_args, field.name))
+ return args
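+
+
+# Illustrative construction (added as documentation; values are arbitrary):
+#
+#   args = Arguments(
+#       hidden_size=1024,
+#       ffn_hidden_size=4096,
+#       moe_num_experts=8,
+#       moe_top_k=2,
+#       mlp_impl='grouped',   # 'sparse' requires triton < 3.2.0
+#       fp16=False,
+#       bf16=True,
+#   )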
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/common.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/common.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+ if args.fp16:
+ return torch.float16
+ elif args.bf16:
+ return torch.bfloat16
+ return None
+
+
+def cast_if_autocast_enabled(tensor):
+ if torch.is_autocast_enabled():
+ if tensor.device.type == 'cuda':
+ dtype = torch.get_autocast_gpu_dtype()
+ elif tensor.device.type == 'cpu':
+ dtype = torch.get_autocast_cpu_dtype()
+ else:
+ raise NotImplementedError()
+ return tensor.to(dtype=dtype)
+ return tensor
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/dmlp_registry.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/dmlp_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/dmlp_registry.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+ 'mlp': {
+ 'grouped': mlp.GroupedMLP,
+ 'sparse': mlp.SparseMLP,
+ },
+ 'glu': {
+ 'grouped': glu.GroupedGLU,
+ 'sparse': glu.SparseGLU,
+ },
+}
+
+
+def get(args: Arguments) -> MlpType:
+ """Returns an MLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ MLP instance. This only contains MLPs for use in dMoEs
+    (i.e. only for the dropless versions of MoEs).
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated MLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
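+
+
+# Illustrative lookup (added as documentation): with the defaults
+# mlp_type='mlp' and mlp_impl='sparse', get(args) returns mlp.SparseMLP(args);
+# with mlp_type='glu' and mlp_impl='grouped' it returns glu.GroupedGLU(args).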
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/dmoe.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/dmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/dmoe.py
@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+# )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+ return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+ def __init__(self, args: Arguments):
+ super(ParallelDroplessMLP, self).__init__(args)
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = mpu.features_per_rank(args)
+ self.blocking = 128
+ self.mlp = dmlp_registry.get(args)
+
+ # Calculate the number of bits needed to represent the column indices
+ # in the intermediate sparse matrix.
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+ self.transpose_sort_end_bit = max(
+ int(np.ceil(np.log2(max_column_index))),
+ 1,
+ )
+
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
+ block_columns = size[1] // self.blocking
+
+ # Sort row indices by column indices to get the transposed matrix's
+ # column indices.
+ #
+ # NOTE: Our sort operation uses the same width indices as the input values.
+ # To avoid overflow when we have large activation matrices we cast to
+ # 32-bit before sorting.
+ _, gather_indices = ops.sort(
+ column_indices.int(),
+ self.transpose_sort_end_bit,
+ )
+
+ # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+ #
+ # row_index * blocks_per_row + column_index % blocks_per_row
+ #
+ # Once we have the block offsets ordered for transposition we can divide
+ # by blocks_per_row to get the transposed column indices.
+ column_indices_t = row_indices.gather(0, gather_indices.long())
+ block_offsets_t = gather_indices.int()
+
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+ nnz_per_column = ops.histogram(column_indices, block_columns)
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+ if nnz_per_column.dim() == 0:
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+ nnz_per_column = nnz_per_column.unsqueeze(0)
+ offsets_t = torch.cat([zero, nnz_per_column])
+ return column_indices_t, offsets_t, block_offsets_t
+
+ def topology(self, x, padded_bins):
+ padded_tokens, _ = x.size()
+ assert padded_tokens % self.blocking == 0
+ if self.ffn_hidden_size % self.blocking != 0:
+ raise ValueError(
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+ f'the block size {self.blocking}. Please update your configuration.',
+ )
+
+ # Offsets for the sparse matrix. All rows have the
+ # same number of nonzero blocks dictated by the
+ # dimensionality of a single expert.
+ block_rows = padded_tokens // self.blocking
+ blocks_per_row = self.ffn_hidden_size // self.blocking
+ offsets = torch.arange(
+ 0,
+ block_rows * blocks_per_row + 1,
+ blocks_per_row,
+ dtype=torch.int32,
+ device=x.device,
+ )
+
+ # Indices for the sparse matrix. The indices for
+ # the intermediate matrix are dynamic depending
+ # on the mapping of tokens to experts.
+ column_indices = ops.topology(
+ padded_bins,
+ self.blocking,
+ block_rows,
+ blocks_per_row,
+ )
+
+ # TODO(tgale): This is unused. Remove the need for this in stk.
+ # For now, use meta init to save the device memory.
+ data = torch.empty(
+ column_indices.numel(),
+ self.blocking,
+ self.blocking,
+ dtype=common.dtype(self.args),
+ device='meta',
+ )
+ shape = (
+ padded_tokens,
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+ )
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+ shape,
+ row_indices,
+ column_indices,
+ offsets,
+ )
+ return stk.Matrix(
+ shape,
+ data,
+ row_indices,
+ column_indices,
+ offsets,
+ column_indices_t,
+ offsets_t,
+ block_offsets_t,
+ )
+
+ def indices_and_padded_bins(self, top_experts):
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ top_experts = top_experts.int()
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+ # Round the token counts up to the block size used in
+        # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = promote_scalar(bins)
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+ def sparse_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(
+ x,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ x = ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ # For use in the base-class parallel_forward_once.
+ def sparse_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Round the token counts up to the block size used in the matrix
+ # multiplication. Calculate the starting position of each bin.
+ padded_tokens_per_expert = ops.round_up(
+ tokens_per_expert,
+ self.blocking,
+ )
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+ padded_bins = promote_scalar(padded_bins)
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+ # Create the sparse matrix topology.
+ with torch.no_grad():
+ topo = self.topology(x, padded_bins)
+
+ # Perform the expert computation.
+ x = self.mlp(x, topo)
+
+ # Un-route the data for the MoE output.
+ return ops.padded_scatter(
+ x,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ padded_bins,
+ top_k,
+ )
+
+ def grouped_forward_once(self, x, expert_weights, top_experts):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ out = self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ -1, # unused
+ self.args.moe_top_k,
+ )
+ return out, tokens_per_expert
+
+ def grouped_permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,  # unused
+ top_k,
+ ):
+
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Perform the expert computation.
+ x = self.mlp(x, tokens_per_expert)
+
+ # Un-route the data for the MoE output.
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ def forward_once(self, x, expert_weights, top_experts):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_forward_once(x, expert_weights, top_experts)
+ else:
+ return self.grouped_forward_once(x, expert_weights, top_experts)
+
+ def permute_and_compute(
+ self,
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+        expert_capacity,
+ top_k,
+ ):
+ if self.args.mlp_impl == 'sparse':
+ return self.sparse_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+ else:
+ return self.grouped_permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+                expert_capacity,
+ top_k,
+ )
+
+
+class dMoE(moe.MoE):
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelDroplessMLP(args)
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/gelu.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/gelu.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# try:
+# import stk
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+import torch.nn.functional as F
+
+
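+# Note on the constants below: this is the derivative of the tanh-approximated
+# GELU, gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
+# where 0.79788456 ~= sqrt(2/pi) and 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi).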
+@torch.jit.script
+def _gelu_backward_inplace(g, x):
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+ return g.mul_(ff)
+
+
+def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+ # NOTE: The two sparse matrices must have the same topology.
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+ return stk.Matrix(
+ x.size(),
+ _gelu_backward_inplace(grad.data, x.data),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
+ return _gelu_backward_inplace(grad, x)
+
+
+def gelu(x: stk.Matrix):
+ assert isinstance(x, stk.Matrix)
+ return stk.Matrix(
+ x.size(),
+ F.gelu(x.data, approximate='tanh'),
+ x.row_indices,
+ x.column_indices,
+ x.offsets,
+ x.column_indices_t,
+ x.offsets_t,
+ x.block_offsets_t,
+ )
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/glu.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/glu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/glu.py
@@ -0,0 +1,244 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# import stk.ops
+# try:
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+# SharedMLP,
+# SparseMLP,
+# create_dmoe_expert_weights,
+# resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
+ SharedMLP,
+ SparseMLP,
+ create_dmoe_expert_weights,
+ resolve_dtensor,
+)
+
+
+class SparseGLU(SparseMLP):
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.v1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ with torch.no_grad():
+ self.v1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+
+ mpu.set_expert_model_parallel_attributes(
+ self.v1,
+ self._should_set_parallelism_attribute,
+ )
+
+ def forward(self, x, topo):
+ if self.args.memory_optimized_mlp:
+ raise NotImplementedError(
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
+ )
+
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Compute the GLU.
+ x1 = stk.ops.sdd(x, w1.t(), topo)
+ x2 = stk.ops.sdd(x, v1.t(), topo)
+
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
+ x1 = stk.ops.mul(activation_fn_out, x2)
+
+ return stk.ops.dsd(x1, w2)
+
+
+class MemoryOptimizedGroupedGLU(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ v1 = v1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
+
+        # Apply the activation to the first projection and gate with v1_out.
+        activation_fn_out = activation_fn(sdd_out) * v1_out
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, v1, w2 = saved_tensors[:3]
+ batch_sizes = saved_tensors[3]
+ x = saved_tensors[4]
+ sdd_out, v1_out = saved_tensors[5:7]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ v1_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out) * v1_out
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+ dv1_out = v1_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dv1.
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ dx = ddsd_out
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
+ return dx, dw1, dv1, dw2, None, None
+
+
+memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
+
+
+class GroupedGLU(SparseGLU):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, v1, w2 = (
+ self.scale_grad(self.w1),
+ self.scale_grad(self.v1),
+ self.scale_grad(self.w2),
+ )
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = w1.view(ne, -1, self.args.hidden_size)
+ v1 = v1.view(ne, -1, self.args.hidden_size)
+ w2 = w2.view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_glu(
+ x,
+ w1,
+ v1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+ x1 = self.args.activation_fn(x1) * x2
+ return gg.ops.gmm(x1, w2, batch_sizes)
+
+
+class SharedGLU(SharedMLP):
+    """GLU for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__(args)
+ self.gate_proj = args.fc_cls(
+ args.hidden_size,
+ self.args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/memory_test.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/memory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/memory_test.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
+
+_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
+
+
+def get_tensors():
+ ptrs = set()
+ out = []
+ for obj in gc.get_objects():
+ if torch.is_tensor(obj):
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
+ continue
+ out.append(obj)
+ ptrs.add(obj.data_ptr())
+ return out
+
+
+def test_memory(
+ group,
+ batch_size,
+ sequence_length,
+ hidden_size,
+ ffn_hidden_size,
+ num_experts,
+ top_k,
+):
+ args = arguments.Arguments(
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ moe_num_experts=num_experts,
+ moe_top_k=top_k,
+ moe_expert_model_parallelism=True,
+ expert_parallel_group=group,
+ fp16=False,
+ bf16=True,
+ device=torch.cuda.current_device(),
+ )
+ layer = dmoe.dMoE(args).cuda()
+
+ x = torch.randn((batch_size, sequence_length, hidden_size),
+ device=torch.cuda.current_device(),
+ dtype=torch.bfloat16).requires_grad_(True)
+ torch.cuda.empty_cache()
+
+ # Run forward + backward.
+ # with torch.autograd.detect_anomaly():
+ out, _ = layer(x)
+ out.mean().backward()
+
+ # Report peak memory.
+ mem = torch.cuda.max_memory_allocated()
+    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
+    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
+
+ # Calculate weight and gradient memory usage.
+ weight_memory = 2 * (
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
+ )
+
+ def grad_numel(x):
+ if x.grad is not None:
+ return x.grad.numel()
+ return 0
+
+ grad_memory = 2 * (
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
+ )
+ weight_memory += grad_memory
+
+    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
+    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
+
+ # Manually calculate GPU memory usage from the garbage
+ # collector.
+ gc.collect()
+ total = 0
+ tensors = get_tensors()
+ tensors = sorted(tensors, key=lambda x: -x.numel())
+ for i, t in enumerate(tensors):
+ total += t.numel()
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
+ del tensors
+
+    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _TESTS:
+ test_memory(group, *args)
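+
+# Launch sketch (illustrative; the exact module path depends on where this
+# file is installed):
+#
+#   torchrun --nproc_per_node=<num_gpus> -m <package>._layers.memory_test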
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/mlp.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/mlp.py
@@ -0,0 +1,587 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# try:
+# import stk
+# import stk.backend.triton_kernels
+# import stk.ops
+# except ImportError:
+# import warnings
+# warnings.warn(
+# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+# )
+
+from .. import stk
+
+import torch
+from packaging import version
+
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+class ScaleGradient(torch.autograd.Function):
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
+ ctx.scale = scale
+ return x
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
+ return grad * ctx.scale, None
+
+
+scale_gradient = ScaleGradient.apply
+
+
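+# resolve_dtensor unwraps a torch.distributed DTensor (available in
+# torch >= 2.0) to its local shard via to_local(), so that the GEMMs below
+# always operate on plain torch.Tensor weights.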
+def resolve_dtensor(weight: torch.Tensor):
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ from torch.distributed._tensor import DTensor
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+def create_moe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ ffn_hidden_size: int,
+ hidden_size: int,
+ init_method: InitFn,
+):
+ # Create the entire weight matrix such that the sampled weights will
+ # not vary between data parallelism and expert model parallelism for
+ # the same random seed.
+ master_weights = torch.empty(
+ num_experts,
+ ffn_hidden_size,
+ hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ )
+ init_method(master_weights)
+
+ if not args.moe_expert_model_parallelism:
+ return master_weights
+
+ # Calculate the amount of sharding in each dimension.
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
+
+ # Calculate the experts per rank.
+ #
+ # NOTE: We assign ranks to be expert parallel before going
+ # tensor parallel.
+ rank = mpu.get_expert_parallel_rank(args)
+ expert_rank = rank % expert_sharding_degree
+ num_experts_per_rank = num_experts // expert_sharding_degree
+ start_expert = expert_rank * num_experts_per_rank
+ end_expert = (expert_rank + 1) * num_experts_per_rank
+
+ # Calculate the rows per rank.
+ row_rank = rank // expert_sharding_degree
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
+ start_row = row_rank * num_rows_per_rank
+ end_row = (row_rank + 1) * num_rows_per_rank
+
+ # Slice the weight matrix to get the chunk for this rank.
+ with torch.no_grad():
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
+ return weights
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
+ experts_per_rank = mpu.experts_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ args.hidden_size,
+ mpu.features_per_rank(args),
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ experts_per_rank,
+ mpu.features_per_rank(args),
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ args.moe_expert_model_parallelism,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ args.moe_expert_model_parallelism,
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ w1 = create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ )
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
+ self.w2.copy_(
+ create_moe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ x = torch.bmm(x, w1)
+ x = self.args.activation_fn(x)
+ return torch.bmm(x, w2)
+
+
+def create_dmoe_expert_weights(
+ args: Arguments,
+ num_experts: int,
+ rows: int,
+ columns: int,
+ init_method: InitFn,
+):
+ weights = create_moe_expert_weights(
+ args,
+ num_experts,
+ rows,
+ columns,
+ init_method,
+ )
+ return weights.view([-1, columns])
+
+
+class MemoryOptimizedMLP(torch.autograd.Function):
+ """Sparse MLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, topo, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ topo_tensors = (
+ topo.row_indices,
+ topo.column_indices,
+ topo.offsets,
+ topo.column_indices_t,
+ topo.offsets_t,
+ topo.block_offsets_t,
+ )
+
+ # Layer 0: x @ w1.t().
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
+
+ # GeLU.
+ activation_fn_out = act_fn(sdd_out, activation_fn)
+
+ # Layer 1: x @ w2.
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.shape = topo.shape
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.data.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx, ddsd_out):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ topo_tensors = saved_tensors[2:8]
+ x = saved_tensors[8]
+ sdd_out_data = saved_tensors[9]
+
+ # rematerialize activation function output
+ activation_fn = ctx.activation_fn
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
+ activation_fn_out, activation_grad_fn = act_fn(
+ sdd_out,
+ activation_fn,
+ return_grad_fn=True,
+ )
+
+ # Compute dw2 with recomputed activation_fn output.
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ stk.backend.triton_kernels.sdd(
+ ddsd_out,
+ w2.t(),
+ dactivation_fn_out.shape,
+ dactivation_fn_out.data,
+ dactivation_fn_out.offsets,
+ dactivation_fn_out.row_indices,
+ dactivation_fn_out.column_indices,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out.data)
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+
+ # Compute dw1.
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ stk.backend.triton_kernels.dsd(
+ dsdd_out.shape,
+ dsdd_out.data,
+ dsdd_out.offsets,
+ dsdd_out.row_indices,
+ dsdd_out.column_indices,
+ dsdd_out.offsets_t,
+ dsdd_out.column_indices_t,
+ dsdd_out.block_offsets_t,
+ False,
+ w1,
+ ddsd_out,
+ )
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_mlp = MemoryOptimizedMLP.apply
+
+
+class SparseMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
+
+ self.w1 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ self.w2 = torch.nn.Parameter(
+ torch.empty(
+ self._num_rows_per_rank,
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+
+ # Initialize the parameters for the MLP.
+ #
+ # NOTE: It is important that we create the weight tensors prior
+        # to creating the master weights and slicing out the piece for
+ # this rank. If the master weights are created first the PyTorch
+ # caching allocator appears to use the same memory block for these
+ # and the slice which causes large increases in our peak memory
+ # usage.
+ with torch.no_grad():
+ self.w1.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.init_method,
+ ),
+ )
+ self.w2.copy_(
+ create_dmoe_expert_weights(
+ args,
+ args.moe_num_experts,
+ args.ffn_hidden_size,
+ args.hidden_size,
+ args.output_layer_init_method,
+ ),
+ )
+
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
+ mpu.set_expert_model_parallel_attributes(
+ self.w1,
+ self._should_set_parallelism_attribute,
+ )
+ mpu.set_expert_model_parallel_attributes(
+ self.w2,
+ self._should_set_parallelism_attribute,
+ )
+
+ self.gradient_scale = None
+ if self.args.moe_expert_model_parallelism:
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
+
+ def scale_grad(self, w):
+ if self.gradient_scale is None:
+ return w
+ return scale_gradient(w, self.gradient_scale)
+
+ def forward(self, x, topo):
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_mlp(
+ x,
+ w1,
+ w2,
+ topo,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ x = stk.ops.sdd(x, w1.t(), topo)
+ activation_fn_out = act_fn(x, self.args.activation_fn)
+ return stk.ops.dsd(activation_fn_out, w2)
+
+
+class MemoryOptimizedGroupedMLP(torch.autograd.Function):
+ """GroupedMLP with manually scheduled memory reuse."""
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
+ # Cast inputs using ctx dtype from AMP
+ if ctx._fwd_used_autocast:
+ x = x.to(ctx._dtype)
+ w1 = w1.to(ctx._dtype)
+ w2 = w2.to(ctx._dtype)
+ # x: [m, k], w1: [n, k], w2: [n, k]
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
+
+ # Layer 0: x @ w1.t().
+ assert gg.backend is not None
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
+
+ # activation_fn
+ activation_fn_out = activation_fn(sdd_out)
+
+ # Layer 1: x @ w2.
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
+
+ # NOTE: Save the input to the layer and the activation_fn input for
+ # gradient computation. We'll re-compute the activation_fn forward
+ # pass in the backward pass to avoid materializing another
+ # intermediate.
+ ctx.x_shape = x.shape
+ ctx.sdd_out_shape = sdd_out.shape
+ ctx.dtype = x.dtype
+ ctx.activation_fn = activation_fn
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
+ return dsd_out
+
+ @staticmethod
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
+ raise ValueError('Expected all MLP inputs to need grad.')
+
+ # Unpack saved tensors
+ # dtype = ctx.dtype
+ saved_tensors = ctx.saved_tensors
+ w1, w2 = saved_tensors[:2]
+ batch_sizes = saved_tensors[2]
+ x = saved_tensors[3]
+ sdd_out = saved_tensors[4]
+
+ # Rematerialize activation_fn output.
+ activation_fn = ctx.activation_fn
+ with torch.set_grad_enabled(True):
+ sdd_out.requires_grad = True
+ activation_fn_out = activation_fn(sdd_out)
+ activation_grad_fn = activation_fn_out.backward
+
+ # Compute dw2 with recomputed activation_fn output.
+ assert gg.backend is not None
+ dw2 = gg.backend.gmm(
+ activation_fn_out,
+ ddsd_out,
+ batch_sizes,
+ trans_a=True,
+ )
+
+ # Compute dactivation_fn_out.
+ #
+ # NOTE: We reuse the activation_fn_out allocation.
+ dactivation_fn_out = activation_fn_out
+ gg.backend.gmm(
+ ddsd_out,
+ w2,
+ batch_sizes,
+ trans_b=True,
+ c=dactivation_fn_out,
+ )
+
+ # Compute dsdd_out.
+ #
+ # NOTE: This reuses the dactivation_fn_out allocation.
+ if activation_fn is DEFAULT_ACTIVATION_FN:
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
+ else:
+ assert activation_grad_fn is not None
+ activation_grad_fn(dactivation_fn_out)
+ dsdd_out = sdd_out.grad
+
+ # Compute dw1.
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
+
+ # Compute dx.
+ #
+ # NOTE: This reuses the ddsd_out allocation.
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
+ dx = ddsd_out
+ return dx, dw1, dw2, None, None
+
+
+memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
+
+
+class GroupedMLP(SparseMLP):
+
+ def forward(self, x, tokens_per_expert):
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
+
+ # Re-shape the weights for the grouped GEMMs.
+ ne = mpu.experts_per_rank(self.args)
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
+
+ if self.args.memory_optimized_mlp:
+ return memory_optimized_grouped_mlp(
+ x,
+ w1,
+ w2,
+ batch_sizes,
+ self.args.activation_fn,
+ )
+
+ # Compute the MLP.
+ assert gg.ops is not None
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+ x = self.args.activation_fn(x)
+ return gg.ops.gmm(x, w2, batch_sizes)
+
+
+class SharedMLP(torch.nn.Module):
+ """MLP for shared expert.
+
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
+ """
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+ self.fc_kwargs: dict[str, Any] = {
+ 'bias': args.bias,
+ 'device': args.device,
+ }
+ self.fc_kwargs.update(args.fc_kwargs)
+
+ self.up_proj = args.fc_cls(
+ args.hidden_size,
+ args.shared_expert_hidden_size,
+ **self.fc_kwargs,
+ )
+ self.act = args.activation_fn
+ self.down_proj = args.fc_cls(
+ args.shared_expert_hidden_size,
+ args.hidden_size,
+ **self.fc_kwargs,
+ )
+ self.down_proj._is_residual = True # a flag for llm-foundry init
+
+ def add_experts_sharedexpert(
+ self,
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ ) -> torch.Tensor:
+ # Helper function to add expert output to shared expert output
+ # with optional weighted sum.
+ if self.args.shared_expert_weighted_sum:
+ # enable using weighted sum for shared expert output
+            # weighted by the number of experts used
+ t_experts = self.args.moe_top_k + 1
+ sh_mlp_out = shared_expert_out / t_experts
+ return sh_mlp_out.add(
+ expert_out,
+ alpha=(self.args.moe_top_k / t_experts),
+ )
+
+ return shared_expert_out + expert_out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act(self.up_proj(x)))
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/moe.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/moe.py
@@ -0,0 +1,507 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+ sort,
+ histogram,
+ inclusive_cumsum,
+ exclusive_cumsum,
+ binned_gather,
+ binned_scatter,
+ gather,
+ scatter,
+ repeat,
+ replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
+
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args: Arguments):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ # tokens_per_expert[i].shape = (num_experts)
+ # expert_scores[i].shape = (tokens, num_experts)
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
+            f'but found {len(expert_scores)}.\nnum_layers = '
+ f'{args.num_layers}\npipeline_model_parallel_size = '
+ f'{args.pipeline_model_parallel_size}\n'
+ 'num_layers_per_virtual_pipeline_stage'
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
+
+ tokens = expert_scores[0].shape[0]
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(ParallelMLP, self).__init__()
+ self.args = args
+
+ # Calculate the number of experts in total and the number of experts
+ # owned by this rank.
+ # world_size = mpu.get_expert_parallel_world_size(args)
+ self.num_experts = args.moe_num_experts
+ self.top_k = self.args.moe_top_k
+
+ # Calculate the number of bits needed to represent the expert indices
+ # so that we can pass it to radix sort.
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+ # Expert MLP.
+ self.mlp = mlp.MLP(args)
+
+ self.bias: Optional[torch.Tensor]
+ if self.args.bias:
+ # Note that the output bias is not parallelized with expert
+ # model parallelism.
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ args.hidden_size,
+ device=args.device,
+ dtype=common.dtype(args),
+ ),
+ )
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ # Select the forward function for the operating mode.
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
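+    # Worked example (illustrative): with 1024 tokens, top_k = 2, a world size
+    # of 8 and 64 experts, capacity factor 1 reserves
+    # int(1 * 2 * 1024 * 8 / 64) = 256 slots per expert on this rank.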
+ def expert_capacity(self, tokens: int) -> int:
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+ """Calculate the load balancing loss contribution."""
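+        # Worked example (illustrative): with 4 experts, 8 tokens and
+        # top_k = 2 there are 16 routing assignments. Under perfectly uniform
+        # routing each expert receives 4 of them and the mean router score per
+        # expert is 0.25, so the loss is
+        # (4 / (8 * 2)) * dot([4, 4, 4, 4], [0.25, 0.25, 0.25, 0.25]) = 1.0.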
+ assert len(expert_scores.size()) == 2
+ tokens, num_experts = expert_scores.size()
+ assert num_experts == self.num_experts
+ assert len(tokens_per_expert.size()) == 1
+ num_experts, = tokens_per_expert.size()
+ assert num_experts == self.num_experts
+ scale = self.num_experts / (tokens * self.top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
+ def indices_and_bins(self,
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Sort the expert ids to produce the scatter/gather
+ # indices for the permutation.
+ #
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
+ # prior? Could we place the `torch.max` operation to return
+ # 32-bit expert indices?
+ top_expert = top_expert.int()
+ # output = ops.sort(top_expert, self.sort_end_bit)
+ output = sort(top_expert, self.sort_end_bit)
+ assert output is not None
+ bin_ids, indices = output
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ # TODO(tgale): Does the sorted data produce a more favorable
+ # data distribution for histogram? Or is the op parallelism
+ # worth more?
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+ tokens_per_expert = histogram(top_expert, self.num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ assert bins is not None
+ bins = bins.view(1) if not len(bins.size()) else bins
+
+ assert isinstance(indices, torch.Tensor)
+ assert isinstance(bin_ids, torch.Tensor)
+ assert isinstance(bins, torch.Tensor)
+ assert isinstance(tokens_per_expert, torch.Tensor)
+
+ return indices, bin_ids, bins, tokens_per_expert
+
+ def permute_and_compute(
+ self,
+ x: torch.Tensor,
+ tokens_per_expert: int, # unused
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor, # unused
+ expert_weights: torch.Tensor,
+ bins: torch.Tensor,
+ expert_capacity: int,
+ top_k: int,
+ ):
+ # Route the tokens for MoE computation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
+ assert output is not None
+ x = output
+
+ # Perform the expert computation. Note that we don't
+ # use biases for these linear operations.
+ x = self.mlp(x)
+
+ # Un-route the data for the MoE output.
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
+
+
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ sl, bs, _ = x.size()
+ expert_capacity = self.expert_capacity(sl * bs)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = self.permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ self.top_k,
+ )
+ return x, tokens_per_expert
+
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ # NOTE: This function implements the same computation as forward_once
+ # but with expert model parallelism.
+ #
+ # 1. Permute the tokens locally so that they are grouped by their
+ # expert assignments. This allows us to transfer all of the tokens
+ # for a remote device in one communication primitive.
+ #
+ # 2. Permute the tokens across the expert parallel devices. After
+ # this is completed each device has all of the tokens assigned to
+ # its set of experts in its local HBM.
+ #
+ # 3. Permute the tokens locally so that they are grouped by their
+        #    expert assignment. After the distributed permutation the tokens
+ # are grouped by which device they came from. We re-order them
+ # locally to allow for efficient computation.
+ #
+ # After this series of permutations we compute the linear layers
+ # and then repeat these three steps in reverse to produce the final
+ # output.
+ #
+ # Compute the mapping of local tokens to experts.
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so every device gets the counts.
+ # repeated_tokens_per_expert = ops.repeat(
+ repeated_tokens_per_expert = repeat(
+ tokens_per_expert,
+ (mpu.hidden_sharding_degree(self.args),),
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = x.view(-1, x.shape[-1])
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+ output = gather(x, indices, bin_ids, bins, self.top_k)
+ assert output is not None
+ x = output
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+ experts_per_rank = mpu.experts_per_rank(self.args)
+
+ # Reshape to [world_size, num_experts_per_rank].
+ world_size = mpu.get_expert_parallel_world_size(self.args)
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+ tokens_received = sum(recv_counts)
+
+ # If we're sharding the experts along the hidden dimension
+ # multiple devices own parts of the same sets of experts.
+ # Replicate the token counts so devices that share experts
+ # get all of the tokens assigned to them.
+ #
+ # TODO(tgale): Fuse this into the prior, local permutation.
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ self.args.expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+ # replicate_bins = ops.inclusive_cumsum(
+ replicate_bins = inclusive_cumsum(
+ parallel_tokens_per_expert.flatten(),
+ 0,
+ )
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ mpu.experts_per_rank(self.args),
+ )
+ # parallel_top_expert = ops.replicate(
+ parallel_top_expert = replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # TODO(tgale): The sort_end_bit here can be reduced.
+ # parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_bin_ids, parallel_indices = sort(
+ parallel_top_expert,
+ self.sort_end_bit,
+ )
+
+ # Calculate the bins boundaries from the token counts.
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+ # If expert_capacity is set to zero, set the number of tokens
+ # per expert to the maximum we need to avoid dropping tokens.
+ tokens, _ = x.size()
+ expert_capacity = self.expert_capacity(tokens)
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if self.args.mlp_impl == 'grouped':
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+ parallel_x_handle.wait()
+ parallel_x = self.permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ )
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ parallel_x,
+ send_counts,
+ recv_counts,
+ self.args.expert_parallel_group,
+ )
+
+ # Reduce along the hidden sharding to get the final outputs.
+ #
+ # TODO(tgale): Fuse this into the following local permutation.
+ shape = (
+ mpu.hidden_sharding_degree(self.args),
+ -1,
+ self.args.hidden_size,
+ )
+ # x = ops.sum(x.view(shape), dim=0)
+ x = x.view(shape).sum(dim=0)
+
+ # Un-permute locally to setup for the next series of operations.
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+ return x, tokens_per_expert.flatten()
+
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+ in_shape = x.size()
+
+ # Compute the experts.
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+ if self.training and self.args.moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, scores))
+ x = x.view(in_shape)
+ if self.bias is not None:
+ if self.args.return_bias:
+ return x, self.bias
+ return x + self.bias
+ return x
+
+
+class MoE(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super(MoE, self).__init__()
+
+ # Token router.
+ self.router = router.LearnedRouter(args)
+
+ # Expert computation helper.
+ self.experts = self._init_experts_mlp(args)
+
+ self.shared_expert = None
+ if args.shared_expert:
+ # SharedExpert computation helper.
+ self.shared_expert = sharedexpert_registry.get(args)
+
+ def _init_experts_mlp(self, args: Arguments):
+ return ParallelMLP(args)
+
+ def forward(self, x: torch.Tensor):
+ # NOTE: If we're going to cast the activations to lower precision
+ # do it before we permute the tokens to save bandwidth.
+ x = common.cast_if_autocast_enabled(x)
+
+ # Compute the expert scores and assignments.
+ scores, expert_weights, top_experts = self.router(x)
+
+ # Compute the experts.
+ out = self.experts(x, scores, expert_weights, top_experts)
+ if self.shared_expert is not None:
+ shared_expert_out = self.shared_expert(x)
+ out = self.shared_expert.add_experts_sharedexpert(
+ shared_expert_out,
+ out,
+ )
+ return out
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/mpu.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/mpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/mpu.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+ def __init__(self):
+ super().__init__(self)
+ self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+ return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, 'expert_model_parallel')
+ setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+ destination_tensor: torch.Tensor,
+ source_tensor: torch.Tensor,
+):
+ if hasattr(source_tensor, 'expert_model_parallel'):
+ setattr(
+ destination_tensor,
+ 'expert_model_parallel',
+ getattr(source_tensor, 'expert_model_parallel'),
+ )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+ for i in range(world_size):
+ dist.barrier(group)
+ if i == rank:
+ print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = min(world_size, args.moe_num_experts)
+
+ if (args.moe_num_experts % esd) != 0:
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+ return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+ world_size = get_expert_parallel_world_size(args)
+ esd = expert_sharding_degree(args)
+ hsd = world_size // esd
+
+ if (args.ffn_hidden_size % hsd) != 0:
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+ )
+ return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+ return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
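+
+
+# Worked example (illustrative numbers, assuming ffn_hidden_size is even):
+# with world_size=8 and moe_num_experts=4, expert_sharding_degree() == 4,
+# hidden_sharding_degree() == 8 // 4 == 2, experts_per_rank() == 1 and
+# features_per_rank() == ffn_hidden_size // 2, i.e. each expert's FFN weights
+# are split across two ranks and summed in the parallel forward pass.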
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/router.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/router.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+ if args.moe_zloss_weight == 0:
+ return
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+ global _ROUTER_LOGITS
+ _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+ global _ROUTER_LOGITS
+
+ if args.moe_zloss_weight == 0:
+ import warnings
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+ return 0
+
+ logits_per_router = _ROUTER_LOGITS
+
+ if args.moe_zloss_in_fp32:
+ logits_per_router = [logits.float() for logits in logits_per_router]
+
+ unscaled_zloss_per_router = torch.stack([
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+ ])
+
+ return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+ out = torch.remainder(out, num_experts)
+ return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
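+
+# Example (illustrative): for a flat tensor of 4 routing slots and 2 experts,
+# _uniform_expert_assignment(torch.zeros(4, dtype=torch.long), 2) returns
+# tensor([0, 1, 0, 1]): experts are assigned round-robin, ignoring the input
+# values and keeping only its shape, dtype and device.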
+
+
+class LearnedRouter(torch.nn.Module):
+
+ def __init__(self, args: Arguments):
+ super().__init__()
+ self.args = args
+
+ # Learned router parameters.
+ #
+ # NOTE: This weight matrix is not parallelized with expert model
+ # parallelism. Each device needs the entire router weight matrix
+ # so that it can route its batch of data correctly.
+ self.layer = torch.nn.Linear(
+ args.hidden_size,
+ args.moe_num_experts,
+ bias=False,
+ dtype=common.dtype(args),
+ device=args.device,
+ )
+ args.init_method(self.layer.weight)
+
+ def jitter(self, x: torch.Tensor):
+ low: float = 1.0 - self.args.moe_jitter_eps
+ high: float = 1.0 + self.args.moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return low + noise * (high - low)
+
+ def _top_k(self, scores: torch.Tensor):
+ if self.args.moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ if self.training and self.args.moe_jitter_eps is not None:
+ x = x * self.jitter(x)
+
+ logits = self.layer(x.view(-1, x.shape[-1]))
+ _save_router_logits(logits, self.args)
+ scores = logits.softmax(dim=-1)
+ expert_weights, expert_indices = self._top_k(scores)
+ if self.args.moe_normalize_expert_weights:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=self.args.moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ expert_indices = (
+ _uniform_expert_assignment(
+ expert_indices,
+ self.args.moe_num_experts,
+ ) if self.args.uniform_expert_assignment else expert_indices
+ )
+ return scores, expert_weights, expert_indices
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/sharedexpert_registry.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/sharedexpert_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_layers/sharedexpert_registry.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
+
+_REGISTRY = {
+ 'mlp': mlp.SharedMLP,
+ 'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+ """Returns an SharedMLP for use in a dMoE instance.
+
+ Uses the provided arguments to instantiate the appropriate
+ SharedMLP instance.
+
+ Args:
+ args: propagated Arguments dataclass.
+
+ Returns:
+ An instantiated SharedMLP constructed using the input args.
+ """
+ if args.mlp_type not in _REGISTRY:
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+ return _REGISTRY[args.mlp_type](args)
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_megablocks_xpu_a45325d.abi3.so b/build/torch29-cxx11-xpu20252-x86_64-linux/_megablocks_xpu_a45325d.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..690eead6de2b9eba259e73b756f7a280bdd33c63
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_megablocks_xpu_a45325d.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9da3abdc02eb695338490793988e2c315411f3bf732e8839af05f41eb3aec66
+size 5197008
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9be29157b8f34992dd924071221b419a35a9145f
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_xpu_a45325d
+ops = torch.ops._megablocks_xpu_a45325d
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_megablocks_xpu_a45325d::{op_name}"
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/_version.py b/build/torch29-cxx11-xpu20252-x86_64-linux/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/_version.py
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/backend/__init__.py b/build/torch29-cxx11-xpu20252-x86_64-linux/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/backend/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/backend/kernels.py b/build/torch29-cxx11-xpu20252-x86_64-linux/backend/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20139324883338ddb312e4b05a72056d74491ac
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/backend/kernels.py
@@ -0,0 +1,557 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+# Stub out triton.autotune when testing in an environment that does not have CUDA.
+# This approach preserves the original code but enables testing without a GPU.
+if torch.cuda.is_available() is False:
+ import warnings
+
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+ def _no_autotune(*args, **kwargs):
+ def deco(fn):
+ return fn
+ return deco
+
+ triton.autotune = _no_autotune
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
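+# Shape sketch (illustrative): for x of shape [tokens, hidden], padded_gather
+# returns [padded_bins[-1], hidden] rows grouped by expert, with each expert's
+# rows padded up to its boundary in `padded_bins`; `gather` below is the
+# unpadded case and returns exactly [tokens * top_k, hidden] rows.
+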
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: There is no padding so the output rows equals the
+ # input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has a greater or equal
+    # number of rows since it may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_b = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_a = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ #
+ # NOTE: We need to zero the output in both directions.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ num_experts = bins.shape[0]
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+ _binned_copy[(num_experts, expert_capacity)](
+ x,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+ _binned_copy[(num_experts, expert_capacity)](
+ out,
+ x,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
+ ],
+ key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Load our indices into the output.
+ expert_idx = tl.program_id(0)
+ entry_idx = tl.program_id(1)
+
+ # Calculate our offset into the output.
+ index_x = expert_idx * expert_capacity + entry_idx
+
+ # Load the index bounds for our bin and calculate
+ # the number of tokens assigned to our expert.
+ start = 0
+ if expert_idx > 0:
+ start = tl.load(bins + expert_idx - 1)
+ end = tl.load(bins + expert_idx)
+ num_tokens = end - start
+
+ # Calculate our offset into the input. If we don't
+    # have an input, exit early.
+ if entry_idx >= num_tokens:
+ return
+ index_out = tl.load(indices + start + entry_idx)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for _ in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+ # Validate the input shapes.
+ assert_is_tensor(x, 3)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bins)
+ assert_equal(bins.shape[0], x.shape[0])
+
+ num_experts, expert_capacity, hidden_size = x.shape
+ tokens = indices.shape[0] // top_k
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
+ x,
+ grad,
+ out,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
+ NUM_COLUMNS=hidden_size,
+ TOP_K=top_k,
+ )
+ return out
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/benchmark_util.py b/build/torch29-cxx11-xpu20252-x86_64-linux/benchmark_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/benchmark_util.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+ print('=' * 60)
+ print(f'{name} Benchmark')
+ print('Benchmark Parameters:')
+ for (key, value) in arguments.items():
+ print(f'{key} = {value}')
+ print('Results:')
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+ print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+ # Warmup iterations.
+ for _ in range(warmup):
+ fn()
+
+ times = []
+ for i in range(iterations):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ start.record()
+ fn()
+ end.record()
+
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ return np.mean(times), np.std(times)
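+
+
+# Example usage (illustrative; assumes CUDA tensors `a` and `b` already exist):
+#
+#   mean_ms, std_ms = benchmark_function(lambda: torch.mm(a, b))
+#   log_benchmark('MatMul', {'m': a.shape[0]}, mean_ms, std_ms)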
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/cpu_fused_moe.py b/build/torch29-cxx11-xpu20252-x86_64-linux/cpu_fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a29e81d6e9c57f64b0a78ac1c0828e45fd9d855
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/cpu_fused_moe.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks CPU Fused MoE Implementation
+#
+# This is a pure Python/PyTorch implementation for CPU.
+# For better performance, consider using the C++ kernel implementation.
+#
+import torch
+import torch.nn.functional as F
+
+
+def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
+ """
+ SwigluOAI activation function used in GptOss models.
+
+ Formula:
+ gate = clamp(gate, max=limit)
+ up = clamp(up, -limit, limit)
+ glu = gate * sigmoid(gate * alpha)
+ output = (up + 1) * glu
+
+ Args:
+ gate: Gate tensor from gate projection
+ up: Up tensor from up projection
+ alpha: Scaling factor for sigmoid (default: 1.702)
+ limit: Clamp limit (default: 7.0)
+
+ Returns:
+ Activated tensor
+ """
+ gate = gate.clamp(max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ return (up + 1) * glu
+
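+# Example (illustrative values): for gate=torch.tensor([2.0]) and
+# up=torch.tensor([0.5]) with the default alpha=1.702 and limit=7.0,
+# glu = 2.0 * sigmoid(3.404) ~= 1.936, so the output is (0.5 + 1) * glu ~= 2.90.
+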
+
+def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+ """
+ SiLU (Swish) activation with element-wise multiplication.
+
+ Formula:
+ output = silu(gate) * up
+
+ Args:
+ gate: Gate tensor
+ up: Up tensor
+
+ Returns:
+ Activated tensor
+ """
+ return F.silu(gate) * up
+
+
+def route_tokens_cpu(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor | None,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_normalize_expert_weights: int | None = None,
+) -> tuple:
+ """
+ Route tokens to experts and compute expert weights and indices (CPU version).
+
+ Args:
+ x: Input tensor [batch, seq, hidden] or [tokens, hidden]
+ router_weight: Router weight [num_experts, hidden]
+ router_bias: Router bias [num_experts] or None
+ moe_top_k: Number of experts per token
+ moe_num_experts: Total number of experts
+ moe_normalize_expert_weights: Normalization order or None
+
+ Returns:
+ Tuple of (logits, expert_weights, expert_indices)
+ """
+ x_flat = x.view(-1, x.shape[-1])
+ logits = F.linear(x_flat, router_weight, router_bias)
+
+ if moe_top_k == 1:
+ expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
+ else:
+ expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
+
+ expert_weights = expert_weights.softmax(dim=-1)
+
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+
+ return logits, expert_weights, expert_indices
+
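+# Example (illustrative shapes): routing 6 tokens of width 64 to 2 of 8 experts.
+#
+#   x = torch.randn(6, 64)
+#   w = torch.randn(8, 64)
+#   logits, weights, ids = route_tokens_cpu(x, w, None, moe_top_k=2, moe_num_experts=8)
+#   # logits: [6, 8]; weights and ids: [6, 2]; weights sum to 1 over the top-k dim.
+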
+
+def cpu_fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ activation: str = "silu",
+ alpha: float = 1.702,
+ limit: float = 7.0,
+ is_interleaved: bool = True,
+) -> torch.Tensor:
+ """
+ CPU Fused MoE using PyTorch operations.
+
+    This implementation loops over experts, but processes all tokens routed to
+    each expert with batched matrix operations, which is more efficient on CPU
+    than per-token loops.
+
+ Args:
+ hidden_states: [num_tokens, hidden_size]
+ w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
+ w2: [num_experts, inter_size, hidden_size] - down_proj weights
+ topk_weights: [num_tokens, topk] - routing weights
+ topk_ids: [num_tokens, topk] - expert indices
+ w1_bias: [num_experts, 2*inter_size] or None
+ w2_bias: [num_experts, hidden_size] or None
+ activation: "silu" or "swigluoai"
+ alpha: swigluoai alpha parameter
+ limit: swigluoai limit parameter
+ is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...] (True for GptOss)
+
+ Returns:
+ output: [num_tokens, hidden_size]
+ """
+ num_tokens, hidden_size = hidden_states.shape
+ num_experts = w1.shape[0]
+ inter_size = w2.shape[1]
+ topk = topk_weights.shape[1]
+
+ # Initialize output
+ output = torch.zeros_like(hidden_states)
+
+    # For each expert, find the (token_idx, topk_pos) pairs routed to it and
+    # process those tokens as one batch.
+ for expert_idx in range(num_experts):
+ # Find tokens assigned to this expert
+ # mask shape: [num_tokens, topk], True where topk_ids == expert_idx
+ mask = (topk_ids == expert_idx)
+
+ if not mask.any():
+ continue
+
+ # Get token indices and topk positions
+ token_indices, topk_positions = torch.where(mask)
+
+ if len(token_indices) == 0:
+ continue
+
+ # Gather input tokens for this expert
+ # current_hidden: [num_selected_tokens, hidden_size]
+ current_hidden = hidden_states[token_indices]
+
+ # Get weights for this expert
+ # w1[expert_idx]: [hidden_size, 2*inter_size]
+ # w2[expert_idx]: [inter_size, hidden_size]
+ expert_w1 = w1[expert_idx] # [hidden_size, 2*inter_size]
+ expert_w2 = w2[expert_idx] # [inter_size, hidden_size]
+
+ # First projection: hidden @ w1 -> [num_selected, 2*inter_size]
+ gate_up = current_hidden @ expert_w1
+
+ # Add bias if present
+ if w1_bias is not None:
+ gate_up = gate_up + w1_bias[expert_idx]
+
+ # Split gate and up projections
+ if is_interleaved:
+ # GptOss uses interleaved layout: [g0, u0, g1, u1, ...]
+ gate = gate_up[..., ::2] # [num_selected, inter_size]
+ up = gate_up[..., 1::2] # [num_selected, inter_size]
+ else:
+ # Standard layout: [gate_all, up_all]
+ gate = gate_up[..., :inter_size]
+ up = gate_up[..., inter_size:]
+
+ # Apply activation
+ if activation == "swigluoai":
+ activated = swigluoai_activation(gate, up, alpha, limit)
+ else: # silu
+ activated = silu_and_mul_activation(gate, up)
+
+ # Second projection: activated @ w2 -> [num_selected, hidden_size]
+ expert_out = activated @ expert_w2
+
+ # Add bias if present
+ if w2_bias is not None:
+ expert_out = expert_out + w2_bias[expert_idx]
+
+ # Apply routing weights and accumulate
+ # weights shape: [num_selected]
+ weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
+ weighted_out = expert_out * weights
+
+ # Accumulate to output
+ output.index_add_(0, token_indices, weighted_out.to(output.dtype))
+
+ return output
+
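+# Example (illustrative shapes, standard non-interleaved SwiGLU layout):
+#
+#   x = torch.randn(4, 32)                     # 4 tokens, hidden_size 32
+#   w1 = torch.randn(8, 32, 2 * 16)            # 8 experts, inter_size 16
+#   w2 = torch.randn(8, 16, 32)
+#   _, tw, ti = route_tokens_cpu(x, torch.randn(8, 32), None, 2, 8)
+#   out = cpu_fused_moe(x, w1, w2, tw, ti, activation="silu", is_interleaved=False)
+#   # out: [4, 32]
+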
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ """
+ CPU MoE MLP module that can be used as a drop-in replacement for
+ the transformers GptOssMLP when using @use_kernel_forward_from_hub.
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer.
+
+ Args:
+ x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights) where:
+ - output: Tensor of same shape as input
+ - expert_weights: Expert weights for each token [tokens, top_k]
+ """
+ # Get MoE parameters from the wrapped modules
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ is_interleaved = True # GptOss uses interleaved layout
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ is_interleaved = False
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Reshape input for fused MoE
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Call CPU fused MoE
+ output = cpu_fused_moe(
+ hidden_states=x_flat,
+ w1=w1,
+ w2=w2,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ activation=activation,
+ alpha=alpha,
+ limit=limit,
+ is_interleaved=is_interleaved,
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+# Export classes and functions
+__all__ = [
+ "MegaBlocksMoeMLP",
+ "cpu_fused_moe",
+ "route_tokens_cpu",
+ "swigluoai_activation",
+ "silu_and_mul_activation",
+]
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/cpu_moe_cpp.py b/build/torch29-cxx11-xpu20252-x86_64-linux/cpu_moe_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..073ff66d24ce348fbb5ed19c9027fadd3f7a9c61
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/cpu_moe_cpp.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# MegaBlocks C++ Optimized CPU MoE
+
+"""
+C++ accelerated MoE with brgemm optimization for Intel AMX.
+Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+"""
+
+import torch
+from typing import Optional
+from .cpu_fused_moe import route_tokens_cpu
+from ._ops import ops
+
+
+def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
+ if tensor is None:
+ return None
+ # Check if it's a DTensor by looking for the to_local() method
+ if hasattr(tensor, "to_local"):
+ return tensor.to_local()
+ return tensor
+
+
+def fused_moe_cpp(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ use_int8_w8a8: bool = False,
+ use_fp8_w8a16: bool = False,
+ use_mxfp4: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ block_size: Optional[list] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ limit: Optional[float] = None,
+ is_vnni: bool = False,
+) -> torch.Tensor:
+ """
+ C++ Fused MoE with brgemm optimization (sglang compatible interface).
+
+ Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
+ Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
+
+ Args:
+ hidden_states: Input tensor [M, K]
+ w1: Gate and up projections [E, 2N, K]
+ w2: Down projection [E, K, N]
+ topk_weights: Expert weights [M, topk]
+ topk_ids: Expert indices [M, topk]
+ inplace: Whether to use hidden_states as output
+ use_int8_w8a8: Use int8 quantization
+ use_fp8_w8a16: Use fp8 quantization
+ use_mxfp4: Use mxfp4 quantization
+ w1_scale, w2_scale: Quantization scales
+ block_size: Block size for fp8
+ a1_scale, a2_scale: Activation scales
+ w1_bias, w2_bias: Optional biases
+ alpha: swigluoai alpha parameter (set to enable swiglu)
+ limit: swigluoai limit parameter (set to enable swiglu)
+ is_vnni: Whether w1/w2 are already in VNNI packed format
+ """
+ # MXFP4/FP8 kernels only support bf16, convert if needed
+ orig_dtype = hidden_states.dtype
+ need_convert = ((use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16) or orig_dtype == torch.float32
+ if need_convert:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # bias must match hidden_states dtype
+ if w1_bias is not None:
+ w1_bias = w1_bias.to(hidden_states.dtype)
+ if w2_bias is not None:
+ w2_bias = w2_bias.to(hidden_states.dtype)
+
+ # Convert DTensor to local tensor for custom ops compatibility (TP mode)
+ hidden_states = _to_local_tensor(hidden_states)
+ w1 = _to_local_tensor(w1)
+ w2 = _to_local_tensor(w2)
+ topk_weights = _to_local_tensor(topk_weights)
+ topk_ids = _to_local_tensor(topk_ids)
+ w1_scale = _to_local_tensor(w1_scale)
+ w2_scale = _to_local_tensor(w2_scale)
+ a1_scale = _to_local_tensor(a1_scale)
+ a2_scale = _to_local_tensor(a2_scale)
+ w1_bias = _to_local_tensor(w1_bias)
+ w2_bias = _to_local_tensor(w2_bias)
+
+ output = ops.fused_experts(
+ hidden_states, w1, w2, topk_weights, topk_ids,
+ inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
+ w1_scale, w2_scale, block_size, a1_scale, a2_scale,
+ w1_bias, w2_bias, alpha, limit, is_vnni
+ )
+
+ # Convert back to original dtype if needed
+ if need_convert:
+ output = output.to(orig_dtype)
+ return output
+
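+# Example call (illustrative; the tensor names are placeholders and the weights
+# are assumed to be pre-packed with ops.convert_weight_packed):
+#
+#   out = fused_moe_cpp(x_flat, w1_packed, w2_packed, topk_weights,
+#                       topk_ids.to(torch.int32), w1_bias=b1, w2_bias=b2,
+#                       alpha=1.702, limit=7.0, is_vnni=True)
+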
+
+class CPUMegaBlocksMoeMLP(torch.nn.Module):
+ """
+ C++ optimized MoE MLP using brgemm.
+ Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
+
+ Usage in transformers:
+ # Will be used via @use_kernel_forward_from_hub decorator
+ """
+ can_torch_compile: bool = True
+
+ def forward(self, x: torch.Tensor) -> tuple:
+ """
+ Forward pass through the MoE layer using C++ kernel.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Tuple of (output, expert_weights)
+ """
+ # Optimization for GPT-OSS model
+ if getattr(self, "use_mxfp4", None) is None:
+ self.use_mxfp4 = False
+
+ w1_scale = None
+ w2_scale = None
+
+ if (
+ not getattr(self, "packed_scales", False)
+ and hasattr(self.experts, "gate_up_proj")
+ and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
+ ):
+ # convert scales
+ data_1 = ops.convert_scale_packed(self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ data_2 = ops.convert_scale_packed(self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous())
+ self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
+ self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
+ self.packed_scales = True
+ self.use_mxfp4 = True
+
+ if not getattr(self, "packed_weight", False) and hasattr(
+ self.experts, "gate_up_proj"
+ ):
+ # convert weights
+ data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
+ data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
+ if self.use_mxfp4:
+ self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
+ else:
+ # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
+ data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
+ data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
+ self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
+ self.experts.down_proj.data = ops.convert_weight_packed(data_2)
+
+ # C++ kernel does not support float32.
+ dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
+ if getattr(self.experts, "gate_up_proj_bias", None) is not None:
+ self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
+ if getattr(self.experts, "down_proj_bias", None) is not None:
+ self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
+
+ self.packed_weight = True
+
+ # Get MoE parameters
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+
+ # Detect activation type
+ if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
+ activation = "swigluoai"
+ alpha = self.experts.alpha
+ limit = self.experts.limit
+ else:
+ activation = getattr(self.experts, "activation", "silu")
+ alpha = 1.702
+ limit = 7.0
+
+ # Get weight tensors
+ if hasattr(self.experts, "gate_up_proj"):
+ w1 = self.experts.gate_up_proj
+ elif hasattr(self.experts, "w1"):
+ w1 = self.experts.w1
+ w3 = getattr(self.experts, "w3", None)
+ if w3 is not None:
+ w1 = torch.cat([w1, w3], dim=-1)
+ else:
+ raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
+
+ if hasattr(self.experts, "down_proj"):
+ w2 = self.experts.down_proj
+ elif hasattr(self.experts, "w2"):
+ w2 = self.experts.w2
+ else:
+ raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")
+
+ # Get optional bias tensors
+ w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
+ w2_bias = getattr(self.experts, "down_proj_bias", None)
+ w1_bias = w1_bias if w1_bias is None else w1_bias.data
+ w2_bias = w2_bias if w2_bias is None else w2_bias.data
+
+ if self.use_mxfp4:
+ w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
+ w2_scale = self.experts.down_proj_precision_config.weight_scale.data
+
+ # Store original shape
+ in_shape = x.size()
+
+ # Route tokens to experts (Python implementation is fast enough)
+ logits, expert_weights, expert_indices = route_tokens_cpu(
+ x,
+ self.router.weight,
+ getattr(self.router, "bias", None),
+ moe_top_k,
+ moe_num_experts,
+ moe_normalize_expert_weights,
+ )
+
+ # Flatten input
+ x_flat = x.view(-1, x.shape[-1])
+
+ # Determine alpha/limit for swiglu activation
+ use_alpha = alpha if activation == "swigluoai" else None
+ use_limit = limit if activation == "swigluoai" else None
+
+ # Call C++ optimized kernel
+ output = fused_moe_cpp(
+ hidden_states=x_flat,
+ w1=w1.data,
+ w2=w2.data,
+ topk_weights=expert_weights,
+ topk_ids=expert_indices.to(torch.int32),
+ inplace=False,
+ use_int8_w8a8=False,
+ use_fp8_w8a16=False,
+ use_mxfp4=self.use_mxfp4,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_size=None,
+ a1_scale=None,
+ a2_scale=None,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ alpha=use_alpha,
+ limit=use_limit,
+ is_vnni=getattr(self, "packed_weight", False),
+ )
+
+ # Restore original shape
+ output = output.view(in_shape)
+
+ return output, expert_weights
+
+
+__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm/__init__.py b/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm/backend.py b/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm/backend.py
@@ -0,0 +1,33 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+# from megablocks._ops import ops as backend # type: ignore
+from .._ops import ops as backend # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+ assert not (trans_a and trans_b)
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
+ assert b.ndim == (2 if trans_a else 3)
+
+ shape = (
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
+ if trans_a else
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+ )
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+ if c is None:
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+ return c
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm/ops.py b/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm/ops.py
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, a, b, batch_sizes, trans_b):
+ ctx.save_for_backward(a, b, batch_sizes)
+ ctx.trans_b = trans_b
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+ @staticmethod
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ a, b, batch_sizes = ctx.saved_tensors
+ trans_b = ctx.trans_b
+
+ agrad = None
+ if ctx.needs_input_grad[0]:
+ agrad = backend.gmm(
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+ bgrad = None
+ if ctx.needs_input_grad[1]:
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
+ bgrad = backend.gmm(
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+ return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
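+
+
+# Example (illustrative shapes; device and dtype support depend on the backend op):
+# three groups of 2, 3 and 1 rows share one grouped GEMM, with one weight
+# matrix per group stacked along the first dimension.
+#
+#   a = torch.randn(6, 16, dtype=torch.bfloat16, device=device)
+#   b = torch.randn(3, 16, 8, dtype=torch.bfloat16, device=device)
+#   batch_sizes = torch.tensor([2, 3, 1])
+#   c = gmm(a, b, batch_sizes)  # c: [6, 8]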
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm_util.py b/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/grouped_gemm_util.py
@@ -0,0 +1,31 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    # The upstream `import grouped_gemm` is skipped here because grouped_gemm
+    # is vendored into this package.
+    _grouped_gemm_is_available = True
+except ImportError:
+ warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+ return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+ msg = (
+ 'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
+ )
+ assert _grouped_gemm_is_available, msg
+
+
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend
+from .grouped_gemm import ops
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/layers.py b/build/torch29-cxx11-xpu20252-x86_64-linux/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c91edcd9e2d1a4ef9eac90217ff481f08ab1886
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/layers.py
@@ -0,0 +1,1232 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any, TYPE_CHECKING
+
+from . import _layers
+from . import ops
+
+# Conditional import for meta kernel registration
+if TYPE_CHECKING:
+
+ def register_fake(fn):
+ return lambda name: fn
+
+else:
+ try:
+ from torch.library import register_fake
+ except ImportError:
+ try:
+ from torch.library import impl_abstract as register_fake
+ except ImportError:
+ # Fallback for older PyTorch versions
+ def register_fake(op_name):
+ def decorator(fn):
+ return fn
+
+ return decorator
+
+
+# Meta kernel implementations for torch.compile compatibility
+def _install_meta_kernels():
+ """Install meta kernels for existing MegaBlocks operations"""
+
+ # Create wrapper functions that check for compilation and return meta tensors
+
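+    # Pattern used by every patch below: keep a reference to the original op,
+    # and at call time return empty tensors with the expected shape/dtype/device
+    # while torch.compile is tracing (torch.compiler.is_compiling()), otherwise
+    # dispatch to the real kernel.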
+ # Patch ops.sort
+ if hasattr(ops, "sort"):
+ original_sort = ops.sort
+
+ def sort_with_meta(x, end_bit=None):
+ if torch.compiler.is_compiling():
+ # print("Using meta kernel for sort")
+ # Meta implementation - return tensors with correct shape/dtype/device
+ return torch.empty_like(x), torch.empty_like(x)
+ # print("Using original sort kernel")
+ return original_sort(x, end_bit)
+
+ ops.sort = sort_with_meta
+
+ # Patch ops.histogram
+ if hasattr(ops, "histogram"):
+ original_histogram = ops.histogram
+
+ def histogram_with_meta(x, max_val):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+ return original_histogram(x, max_val)
+
+ ops.histogram = histogram_with_meta
+
+ # Patch ops.inclusive_cumsum
+ if hasattr(ops, "inclusive_cumsum"):
+ original_inclusive_cumsum = ops.inclusive_cumsum
+
+ def inclusive_cumsum_with_meta(x, dim):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty_like(x)
+ return original_inclusive_cumsum(x, dim)
+
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+ # Patch ops.binned_gather
+ if hasattr(ops, "binned_gather"):
+ original_binned_gather = ops.binned_gather
+
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - output shape based on bin_size
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (bin_size, x.size(1), hidden_size),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ else:
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+ ops.binned_gather = binned_gather_with_meta
+
+ # Patch ops.binned_scatter
+ if hasattr(ops, "binned_scatter"):
+ original_binned_scatter = ops.binned_scatter
+
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - typically reduces to 2D
+ if x.dim() >= 3:
+ return torch.empty(
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty_like(x)
+ return original_binned_scatter(x, indices, weights, bins, top_k)
+
+ ops.binned_scatter = binned_scatter_with_meta
+
+ # Patch ops.gather
+ if hasattr(ops, "gather"):
+ original_gather = ops.gather
+
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if x.dim() >= 2:
+ hidden_size = x.size(-1)
+ return torch.empty(
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+ return original_gather(x, indices, bin_ids, bins, top_k)
+
+ ops.gather = gather_with_meta
+
+ # Patch ops.scatter
+ if hasattr(ops, "scatter"):
+ original_scatter = ops.scatter
+
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+ if torch.compiler.is_compiling():
+ # Meta implementation - restore sequence shape
+ seq_len = (
+ indices.size(0) // top_k
+ if indices.numel() > 0 and top_k > 0
+ else x.size(0)
+ )
+ if x.dim() >= 2:
+ return torch.empty(
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+ )
+ else:
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ ops.scatter = scatter_with_meta
+
+ # Patch ops.replicate
+ if hasattr(ops, "replicate"):
+ original_replicate = ops.replicate
+
+ def replicate_with_meta(x, bins, num_outputs):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ return torch.empty(
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+ )
+ return original_replicate(x, bins, num_outputs)
+
+ ops.replicate = replicate_with_meta
+
+ # Patch ops.repeat (if it's a regular function)
+ if hasattr(ops, "repeat"):
+ original_repeat = ops.repeat
+
+ def repeat_with_meta(x, repeats):
+ if torch.compiler.is_compiling():
+ # Meta implementation
+ if isinstance(repeats, (tuple, list)):
+ new_shape = list(x.shape)
+ for i, rep in enumerate(repeats):
+ if i < len(new_shape):
+ new_shape[i] *= rep
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ else:
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+ return original_repeat(x, repeats)
+
+ ops.repeat = repeat_with_meta
+
+
+# Install meta kernels on import
+try:
+ _install_meta_kernels()
+except Exception as e:
+ # If meta kernel installation fails, continue without them
+ # torch.compile may not work but the library will still function
+ import warnings
+
+ warnings.warn(
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+ )
+
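+# With the meta kernels patched in, torch.compile can trace through the routing
+# ops: while compiling, each patched op returns an empty tensor of the expected
+# shape instead of launching the real kernel. Hedged usage sketch, assuming
+# `mlp` is an already-configured MegaBlocksMoeMLP on a supported device:
+#
+#     compiled_mlp = torch.compile(mlp)
+#     output, expert_weights = compiled_mlp(hidden_states)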
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+ tensor: torch.Tensor,
+ is_parallel: bool,
+):
+ assert not hasattr(tensor, "expert_model_parallel")
+ setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+) -> int:
+ esd = min(world_size, moe_num_experts)
+ if (moe_num_experts % esd) != 0:
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+ return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+ world_size: int,
+ moe_num_experts: int,
+ ffn_hidden_size: int,
+) -> int:
+ esd = expert_sharding_degree(world_size, moe_num_experts)
+ hsd = world_size // esd
+ if (ffn_hidden_size % hsd) != 0:
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+ if (esd * hsd) != world_size:
+ raise ValueError(
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+ )
+ return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+ moe_num_experts: int,
+ world_size: int,
+) -> int:
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+ return ffn_hidden_size // hidden_sharding_degree(
+ world_size, moe_num_experts, ffn_hidden_size
+ )
+
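+
+# Illustrative sketch (not executed at import): a worked example of the
+# sharding arithmetic above. With world_size=8, moe_num_experts=4 and
+# ffn_hidden_size=3072, the experts are sharded 4 ways, the hidden dimension
+# 2 ways, and each rank holds one expert with 1536 of its features.
+def _example_sharding_arithmetic():
+    world_size, num_experts, ffn = 8, 4, 3072
+    assert expert_sharding_degree(world_size, num_experts) == 4
+    assert hidden_sharding_degree(world_size, num_experts, ffn) == 2
+    assert experts_per_rank(num_experts, world_size) == 1
+    assert features_per_rank(ffn, world_size, num_experts) == 1536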
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+ low = 1.0 - moe_jitter_eps
+ high = 1.0 + moe_jitter_eps
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+ return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+ if moe_top_k == 1:
+ return scores.max(dim=-1, keepdim=True)
+ return torch.topk(scores, moe_top_k, dim=-1)
+
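+
+# Illustrative sketch (not executed at import): jitter rescales each activation
+# by a uniform factor in [1 - eps, 1 + eps], and compute_top_k returns the k
+# largest router logits per token together with their expert indices.
+def _example_jitter_and_top_k():
+    x = torch.randn(2, 4)
+    jittered = apply_jitter(x, moe_jitter_eps=0.01)
+    assert jittered.shape == x.shape
+    scores = torch.tensor([[0.1, 0.7, 0.2, 0.9]])
+    values, indices = compute_top_k(scores, moe_top_k=2)
+    # Experts 3 (0.9) and 1 (0.7) are selected for the single token.
+    assert indices.tolist() == [[3, 1]]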
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: torch.Tensor,
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if training and moe_jitter_eps is not None:
+ x = apply_jitter(x, moe_jitter_eps)
+
+ x_flat = x.view(-1, x.shape[-1])
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+ expert_weights = expert_weights.softmax(dim=-1)
+ if moe_normalize_expert_weights is not None:
+ expert_weights = expert_weights / torch.norm(
+ expert_weights,
+ p=moe_normalize_expert_weights,
+ dim=-1,
+ keepdim=True,
+ )
+ if uniform_expert_assignment:
+ expert_indices = _layers.router._uniform_expert_assignment(
+ expert_indices,
+ moe_num_experts,
+ )
+
+ return logits, expert_weights, expert_indices
+
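+
+# Illustrative sketch (not executed at import; shapes are hypothetical):
+# routing a [batch, seq, hidden] activation through a randomly initialized
+# router. Logits have one row per token and one column per expert, while
+# expert_weights / expert_indices are [num_tokens, moe_top_k].
+def _example_route_tokens():
+    hidden, num_experts, top_k = 16, 8, 2
+    x = torch.randn(2, 3, hidden)
+    router_weight = torch.randn(num_experts, hidden)
+    logits, weights, indices = route_tokens(
+        x, router_weight, router_bias=None,
+        moe_top_k=top_k, moe_num_experts=num_experts,
+    )
+    assert logits.shape == (6, num_experts)
+    assert weights.shape == (6, top_k) and indices.shape == (6, top_k)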
+
+# Scale the gradient of the weights
+def scale_grad(
+ w: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ if gradient_scale is None:
+ return w
+ return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(
+ x: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ limit: float = 7.0,
+):
+ # Scale weights
+ w1 = scale_grad(w1, gradient_scale)
+ w2 = scale_grad(w2, gradient_scale)
+ w1_bias = scale_grad(w1_bias, gradient_scale)
+ w2_bias = scale_grad(w2_bias, gradient_scale)
+
+ # Resolve dtensors
+ w1 = _layers.mlp.resolve_dtensor(w1)
+ w2 = _layers.mlp.resolve_dtensor(w2)
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+ # Forward pass
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ next_states = torch.bmm(((up + 1) * glu), w2)
+ next_states += w2_bias[..., None, :]
+ return next_states
+
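+
+# Note on the layout above: w1 fuses the gate and up projections, with gate
+# values in the even output channels and up values in the odd ones, and the
+# activation is a clamped, sigmoid-gated unit (glu = gate * sigmoid(alpha *
+# gate); output = (up + 1) * glu). A minimal standalone sketch of just that
+# activation, with hypothetical shapes:
+def _example_gated_activation(alpha: float = 1.702, limit: float = 7.0):
+    gate_up = torch.randn(2, 4, 8)           # [experts, tokens, 2 * ffn]
+    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+    gate = gate.clamp(max=limit)
+    up = up.clamp(min=-limit, max=limit)
+    glu = gate * torch.sigmoid(gate * alpha)
+    return (up + 1) * glu                    # [experts, tokens, ffn]
+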
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+ x: torch.Tensor,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ activation_fn: Optional[Any] = None,
+ gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+ # Default activation function
+ if activation_fn is None:
+ activation_fn = torch.nn.functional.gelu
+
+ # Scale weights
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+ if up_proj_bias is not None:
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+ if down_proj_bias is not None:
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+ # Resolve dtensors
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+ if up_proj_bias is not None:
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+ if down_proj_bias is not None:
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+ # Up projection
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+ # Activation
+ x = activation_fn(x)
+
+ # Down projection
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+ return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+ shared_expert_out: torch.Tensor,
+ expert_out: torch.Tensor,
+ shared_expert_weighted_sum: bool = False,
+ moe_top_k: int = 1,
+) -> torch.Tensor:
+ if shared_expert_weighted_sum:
+ # Weighted sum based on number of experts used
+ total_experts = moe_top_k + 1
+ shared_weight = 1.0 / total_experts
+ expert_weight = moe_top_k / total_experts
+ return shared_expert_out * shared_weight + expert_out * expert_weight
+ else:
+ # Simple addition
+ return shared_expert_out + expert_out
+
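+
+# Illustrative sketch (not executed at import): with moe_top_k=3 the weighted
+# sum mode blends the shared expert at weight 1/4 and the routed experts at
+# weight 3/4; the default mode simply adds the two outputs.
+def _example_combine_outputs():
+    shared = torch.ones(2, 4)
+    routed = torch.zeros(2, 4)
+    weighted = combine_expert_shared_outputs(
+        shared, routed, shared_expert_weighted_sum=True, moe_top_k=3
+    )
+    assert torch.allclose(weighted, torch.full((2, 4), 0.25))
+    summed = combine_expert_shared_outputs(shared, routed)
+    assert torch.allclose(summed, torch.ones(2, 4))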
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+ global _LOAD_BALANCING_LOSS
+ _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+ if args.moe_loss_weight == 0:
+ return 0.0
+
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+ if len(expert_scores) != num_layers_per_pipeline_stage:
+ raise ValueError(
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(expert_scores)}.\nnum_layers = "
+ f"{args.num_layers}\npipeline_model_parallel_size = "
+ f"{args.pipeline_model_parallel_size}\n"
+ "num_layers_per_virtual_pipeline_stage"
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
+ )
+
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
+ assert all(
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+ )
+
+ tokens = expert_scores[0].shape[0]
+ assert all(
+ (
+ (
+ x.ndim == 2
+ and x.shape[1] == args.moe_num_experts
+ and x.shape[0] == tokens
+ )
+ for x in expert_scores
+ )
+ )
+
+ # Concatenate the contributions of each layer and convert to
+ # the correct types and formats for the dot product.
+ expert_scores = torch.cat(expert_scores, dim=1)
+ if args.moe_lbl_in_fp32:
+ expert_scores = expert_scores.float()
+ if tokens != 0:
+ expert_scores = expert_scores.mean(dim=0)
+ else:
+ expert_scores = expert_scores.sum(dim=0)
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+ assert tokens_per_expert.numel() == expected_values
+ assert expert_scores.numel() == expected_values
+
+ # Calculate the total scale across all factors.
+ #
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
+ scale = scale_numerator / scale_denominator
+ return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup],
+ moe_capacity_factor: float,
+ moe_expert_model_parallelism: bool,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
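+
+# Illustrative sketch (not executed at import): without expert parallelism and
+# a capacity factor of 1.25, 1024 tokens routed top-4 across 128 experts give
+# 4 * 1024 / 128 = 32 tokens per expert, i.e. a capacity of int(1.25 * 32) = 40.
+def _example_expert_capacity():
+    capacity = expert_capacity(
+        tokens=1024,
+        top_k=4,
+        num_experts=128,
+        expert_parallel_group=None,
+        moe_capacity_factor=1.25,
+        moe_expert_model_parallelism=False,
+    )
+    assert capacity == 40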
+
+def load_balancing_loss(
+ tokens_per_expert: torch.Tensor,
+ expert_scores: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+):
+ assert len(expert_scores.size()) == 2
+    tokens, score_num_experts = expert_scores.size()
+    assert score_num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (count_num_experts,) = tokens_per_expert.size()
+    assert count_num_experts == num_experts
+ scale = num_experts / (tokens * top_k)
+ return scale * torch.dot(
+ tokens_per_expert.to(expert_scores.dtype),
+ expert_scores.mean(dim=0),
+ )
+
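+
+# Illustrative sketch (not executed at import): a perfectly balanced router
+# over 4 experts with top-1 routing and uniform scores yields a loss of 1.0
+# (scale = num_experts / (tokens * top_k), dotted with counts and mean scores).
+def _example_load_balancing_loss():
+    tokens, num_experts = 8, 4
+    tokens_per_expert = torch.full((num_experts,), tokens / num_experts)
+    expert_scores = torch.full((tokens, num_experts), 1.0 / num_experts)
+    loss = load_balancing_loss(
+        tokens_per_expert, expert_scores, top_k=1, num_experts=num_experts
+    )
+    assert torch.isclose(loss, torch.tensor(1.0))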
+
+def indices_and_bins(
+ top_expert: torch.Tensor,
+ sort_end_bit: int,
+ num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ top_expert = top_expert.int()
+
+ # Ensure contiguous memory layout
+ top_expert = top_expert.contiguous()
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(top_expert.device):
+ output = ops.sort(top_expert, sort_end_bit)
+ bin_ids, indices = output
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return indices, bin_ids, bins, tokens_per_expert
+
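+
+# Illustrative sketch (not executed at import; needs a device these kernels
+# were built for): for expert assignments [2, 0, 2, 1] over 4 experts the
+# histogram is [1, 1, 2, 0] and the inclusive cumsum gives expert bin
+# boundaries [1, 2, 4, 4]; `indices` is the token order sorted by expert.
+def _example_indices_and_bins():
+    top_expert = torch.tensor([2, 0, 2, 1], device="cuda")
+    indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+        top_expert, sort_end_bit=2, num_experts=4
+    )
+    assert tokens_per_expert.tolist() == [1, 1, 2, 0]
+    assert bins.tolist() == [1, 2, 4, 4]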
+
+def expert_capacity_fn(
+ tokens: int,
+ top_k: int,
+ num_experts: int,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+) -> int:
+ world_size = (
+ dist.get_world_size(expert_parallel_group)
+ if moe_expert_model_parallelism
+ else 1
+ )
+ tokens_per_expert = top_k * tokens * world_size / num_experts
+ return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+):
+ # Route tokens to experts
+ x = x.view(-1, x.shape[-1])
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+ # Expert computation
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+ # Ensure CUB knows which device to use
+ with torch.cuda.device(x.device):
+ # Route tokens back
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+ return out
+
+
+def forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ mlp_impl: Optional[str] = None,
+):
+ # x: [sl, bs, hs]
+ # expert_weights: [sl * bs, top-k]
+ # top_experts: [sl * bs, top-k]
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ with torch.no_grad():
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate expert capacity
+ sl, bs, _ = x.size()
+
+ expert_capacity = expert_capacity_fn(
+ sl * bs,
+ top_k,
+ num_experts,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+
+ if expert_capacity == 0:
+ expert_capacity = torch.max(tokens_per_expert).item()
+
+ x = permute_and_compute(
+ x,
+ tokens_per_expert,
+ indices,
+ bin_ids,
+ expert_weights,
+ bins,
+ expert_capacity,
+ top_k,
+ w1,
+ w2,
+ w1_bias,
+ w2_bias,
+ gradient_scale,
+ alpha,
+ )
+ return x, tokens_per_expert
+
+
+def parallel_forward_once(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ top_experts: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ top_k: int = 4,
+ num_experts: int = 128,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = True,
+ hidden_size: int = 1152,
+ mlp_impl: Optional[str] = "grouped",
+):
+ # Flatten inputs
+ expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+
+ # TODO: remove debugging var
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
+ with torch.no_grad():
+ # Step 1: Local permutation setup
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+ top_experts, sort_end_bit, num_experts
+ )
+
+ # Calculate sharding parameters
+ world_size = dist.get_world_size(expert_parallel_group)
+ hidden_sharding_deg = hidden_sharding_degree(
+ world_size, num_experts, hidden_size
+ )
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+ # Replicate token counts for hidden sharding
+ repeated_tokens_per_expert = ops.repeat(
+ tokens_per_expert, (hidden_sharding_deg,)
+ )
+
+ # Exchange token counts across devices
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+
+ # Ensure CUB knows which device to use
+ tpe_handle = dist.all_to_all_single(
+ parallel_tokens_per_expert,
+ repeated_tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Step 2: Local permutation - group tokens by target device
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+ # Step 3: Compute communication counts and exchange tokens
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape for per-device calculations
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank_val
+ )
+
+ # Calculate send/recv counts
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+ tokens_received = sum(recv_counts)
+
+ # Replicate for hidden sharding
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+ # Cross-device token exchange
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
+ )
+
+ with torch.no_grad():
+ # Step 4: Setup for local expert computation
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ # Create expert indices for received tokens
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ num_experts * hidden_sharding_deg,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank_val,
+ )
+ parallel_top_expert = ops.replicate(
+ parallel_top_expert.unsqueeze(dim=0),
+ replicate_bins,
+ tokens_received,
+ ).flatten()
+
+ # Sort tokens by expert assignment
+ parallel_bin_ids, parallel_indices = ops.sort(
+ parallel_top_expert,
+ sort_end_bit,
+ )
+
+ # Calculate bins for local experts
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+ dim=0, dtype=torch.int
+ )
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+ parallel_bins = (
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+ )
+
+ # Calculate expert capacity
+ expert_capacity = expert_capacity_fn(
+ tokens_received,
+ top_k,
+ experts_per_rank_val,
+ expert_parallel_group,
+ moe_capacity_factor,
+ moe_expert_model_parallelism,
+ )
+ if expert_capacity == 0:
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+ # Locally permute the tokens and perform the expert computation.
+ # Block to make sure that the cross-device permutation is complete.
+ if mlp_impl == "grouped":
+ # GroupedMLP requires counts on CPU. We can use the tensor already
+ # moved to CPU for the prior all_to_all, which avoids an extra
+ # device synchronization.
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+ dim=0,
+ dtype=torch.int,
+ )
+
+ # Step 5: Expert computation
+ parallel_x_handle.wait()
+
+ parallel_x = permute_and_compute(
+ parallel_x,
+ parallel_tokens_per_expert,
+ parallel_indices,
+ parallel_bin_ids,
+ None, # expert_weights
+ parallel_bins,
+ expert_capacity,
+ top_k=1,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ )
+
+ # Step 6: Reverse communication - send results back
+ x, _ = _layers.all_to_all.all_to_all(
+ parallel_x, send_counts, recv_counts, expert_parallel_group
+ )
+
+ # Step 7: Reduce across hidden sharding dimension
+ shape = (hidden_sharding_deg, -1, hidden_size)
+ x = x.view(shape).sum(dim=0)
+
+ # Step 8: Final local unpermutation
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+ return x, tokens_per_expert.flatten()
+
+
+def moe_forward(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # Route tokens to experts
+ logits, expert_weights, expert_indices = route_tokens(
+ x,
+ router_weight,
+ router_bias,
+ moe_top_k,
+ moe_num_experts,
+ moe_jitter_eps,
+ moe_normalize_expert_weights,
+ uniform_expert_assignment,
+ training,
+ )
+
+ # Create router scores for output
+ router_scores = (
+ torch.zeros_like(logits)
+ .scatter_(1, expert_indices, expert_weights)
+ .transpose(0, 1)
+ )
+
+ in_shape = x.size()
+
+ # Prepare forward function arguments
+ forward_args = {
+ "x": x,
+ "expert_weights": expert_weights,
+ "top_experts": expert_indices,
+ "w1": w1,
+ "w2": w2,
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ "gradient_scale": gradient_scale,
+ "alpha": alpha,
+ "sort_end_bit": sort_end_bit,
+ "top_k": moe_top_k,
+ "num_experts": moe_num_experts,
+ "expert_parallel_group": expert_parallel_group,
+ "moe_capacity_factor": moe_capacity_factor,
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
+ "mlp_impl": mlp_impl,
+ }
+
+ # Add hidden_size for parallel forward
+ if moe_expert_model_parallelism and hidden_size is not None:
+ forward_args["hidden_size"] = hidden_size
+ elif moe_expert_model_parallelism and hidden_size is None:
+ # Infer hidden_size from input shape
+ forward_args["hidden_size"] = x.shape[-1]
+
+ # Compute expert outputs
+ x, tokens_per_expert = forward_fn(**forward_args)
+
+ # Save load balancing loss if needed
+ moe_loss_weight = 0.0 # Can be made configurable
+ if training and moe_loss_weight > 0:
+ save_load_balancing_loss((tokens_per_expert, logits))
+
+ # Restore original shape
+ x = x.view(in_shape)
+
+ return x, expert_weights, router_scores
+
+
+def moe_forward_with_shared_expert(
+ x: torch.Tensor,
+ router_weight: torch.Tensor,
+ router_bias: Optional[torch.Tensor],
+ moe_top_k: int,
+ moe_num_experts: int,
+ moe_jitter_eps: float = None,
+ moe_normalize_expert_weights: int = None,
+ uniform_expert_assignment: bool = False,
+ training: bool = False,
+ w1: torch.Tensor = None,
+ w2: torch.Tensor = None,
+ w1_bias: torch.Tensor = None,
+ w2_bias: torch.Tensor = None,
+ gradient_scale: Optional[float] = None,
+ alpha: float = 1.702,
+ sort_end_bit: int = 0,
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
+ moe_capacity_factor: float = 1.0,
+ moe_expert_model_parallelism: bool = False,
+ forward_fn: Any = None,
+ hidden_size: int = None,
+ mlp_impl: str = "grouped",
+ # Shared expert parameters
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
+ shared_expert_weighted_sum: bool = False,
+ shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # First, compute regular MoE forward pass
+ expert_out, expert_weights, router_scores = moe_forward(
+ x=x,
+ router_weight=router_weight,
+ router_bias=router_bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=training,
+ w1=w1,
+ w2=w2,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
+ forward_fn=forward_fn,
+ hidden_size=hidden_size,
+ mlp_impl=mlp_impl,
+ )
+
+ # If shared expert weights provided, compute shared expert output
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+ shared_expert_out = shared_mlp_forward(
+ x=x,
+ up_proj_weight=shared_up_proj_weight,
+ down_proj_weight=shared_down_proj_weight,
+ up_proj_bias=shared_up_proj_bias,
+ down_proj_bias=shared_down_proj_bias,
+ activation_fn=shared_activation_fn,
+ gradient_scale=gradient_scale,
+ )
+
+ # Combine expert outputs
+ combined_out = combine_expert_shared_outputs(
+ shared_expert_out=shared_expert_out,
+ expert_out=expert_out,
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
+ moe_top_k=moe_top_k,
+ )
+
+ return combined_out, expert_weights, router_scores
+
+ # Return regular MoE output if no shared expert
+ return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+ hidden_size: int,
+ shared_expert_hidden_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ init_method: Any,
+ output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+ if output_layer_init_method is None:
+ output_layer_init_method = init_method
+
+ # Create weight tensors
+ up_proj_weight = torch.empty(
+ shared_expert_hidden_size,
+ hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+ down_proj_weight = torch.empty(
+ hidden_size,
+ shared_expert_hidden_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Initialize weights
+ init_method(up_proj_weight)
+ output_layer_init_method(down_proj_weight)
+
+ # No bias by default
+ return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+ # Extract device_mesh from child's unused pre_hook closure
+ try:
+ # Find the pre-hook that contains 'device_mesh' in its closure
+ hook = next(
+ h
+ for h in model.experts._forward_pre_hooks.values()
+ if "device_mesh" in h.__code__.co_freevars
+ )
+ # Extract the device_mesh from the closure
+ return hook.__closure__[
+ hook.__code__.co_freevars.index("device_mesh")
+ ].cell_contents
+ except Exception:
+ return None
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+ can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+ output, expert_weights_out, *_ = moe_forward(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ )
+ return output, expert_weights_out
+
+
+# Export main classes
+__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+ def __init__(self):
+ super().__init__()
+ # Shared expert weights will be set by the user
+ self.shared_up_proj_weight = None
+ self.shared_down_proj_weight = None
+ self.shared_up_proj_bias = None
+ self.shared_down_proj_bias = None
+ self.shared_expert_weighted_sum = False
+ self.shared_activation_fn = None
+
+ def set_shared_expert_weights(
+ self,
+ up_proj_weight: torch.Tensor,
+ down_proj_weight: torch.Tensor,
+ up_proj_bias: Optional[torch.Tensor] = None,
+ down_proj_bias: Optional[torch.Tensor] = None,
+ weighted_sum: bool = False,
+ activation_fn: Optional[Any] = None,
+ ):
+ self.shared_up_proj_weight = up_proj_weight
+ self.shared_down_proj_weight = down_proj_weight
+ self.shared_up_proj_bias = up_proj_bias
+ self.shared_down_proj_bias = down_proj_bias
+ self.shared_expert_weighted_sum = weighted_sum
+ self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ moe_top_k = getattr(self.router, "top_k", 4)
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
+ alpha = getattr(self.experts, "alpha", 1.0)
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+ moe_normalize_expert_weights = getattr(
+ self.experts, "normalize_expert_weights", None
+ )
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
+ if expert_parallel_group is None:
+ device_mesh = get_device_mesh(self)
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+ has_parallel = (
+ expert_parallel_group is not None
+ and dist.is_initialized()
+ and dist.get_world_size(expert_parallel_group) > 1
+ )
+ forward_fn = parallel_forward_once if has_parallel else forward_once
+
+ sort_end_bit = max(
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+ )
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+ x=x,
+ router_weight=self.router.weight,
+ router_bias=self.router.bias,
+ moe_top_k=moe_top_k,
+ moe_num_experts=moe_num_experts,
+ moe_jitter_eps=moe_jitter_eps,
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
+ uniform_expert_assignment=uniform_expert_assignment,
+ training=self.training,
+ w1=self.experts.gate_up_proj,
+ w2=self.experts.down_proj,
+ w1_bias=self.experts.gate_up_proj_bias,
+ w2_bias=self.experts.down_proj_bias,
+ gradient_scale=gradient_scale,
+ alpha=alpha,
+ sort_end_bit=sort_end_bit,
+ expert_parallel_group=expert_parallel_group,
+ moe_capacity_factor=moe_capacity_factor,
+ moe_expert_model_parallelism=has_parallel,
+ forward_fn=forward_fn,
+ hidden_size=self.experts.hidden_size,
+ mlp_impl=mlp_impl,
+ # Shared expert parameters
+ shared_up_proj_weight=self.shared_up_proj_weight,
+ shared_down_proj_weight=self.shared_down_proj_weight,
+ shared_up_proj_bias=self.shared_up_proj_bias,
+ shared_down_proj_bias=self.shared_down_proj_bias,
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+ shared_activation_fn=self.shared_activation_fn,
+ )
+ return output, expert_weights_out
+
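+
+# Illustrative sketch (not executed at import; sizes and device are
+# hypothetical): attaching a shared expert to a MegaBlocksMoeMLPWithSharedExpert
+# instance using the create_shared_expert_weights helper defined above.
+def _example_shared_expert_setup(mlp: MegaBlocksMoeMLPWithSharedExpert):
+    up_w, down_w, up_b, down_b = create_shared_expert_weights(
+        hidden_size=1152,
+        shared_expert_hidden_size=3072,
+        device=torch.device("cuda"),
+        dtype=torch.bfloat16,
+        init_method=torch.nn.init.kaiming_uniform_,
+    )
+    mlp.set_shared_expert_weights(
+        up_w, down_w, up_b, down_b, weighted_sum=True
+    )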
+
+# Patch for XPU or CPU support
+if hasattr(torch, "xpu") and torch.xpu.is_available():
+ from .xpu_fused_moe import MegaBlocksMoeMLP
+
+from .cpu_moe_cpp import CPUMegaBlocksMoeMLP
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/megablocks/__init__.py b/build/torch29-cxx11-xpu20252-x86_64-linux/megablocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/megablocks/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/metadata.json b/build/torch29-cxx11-xpu20252-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..b911d0a2549a35a1c65ab7e77d32e5aac23cd6ac
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/metadata.json
@@ -0,0 +1,8 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "xpu"
+ }
+}
\ No newline at end of file
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/ops/__init__.py b/build/torch29-cxx11-xpu20252-x86_64-linux/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/ops/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
+
+__all__ = [
+ 'binned_gather',
+ 'binned_scatter',
+ 'exclusive_cumsum',
+ 'inclusive_cumsum',
+ 'gather',
+ 'histogram',
+ 'padded_gather',
+ 'padded_scatter',
+ 'repeat',
+ 'replicate',
+ 'round_up',
+ 'scatter',
+ 'sort',
+ 'sum',
+ 'topology',
+]
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/ops/all_to_all_benchmark.py b/build/torch29-cxx11-xpu20252-x86_64-linux/ops/all_to_all_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/ops/all_to_all_benchmark.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+ (8, 1024),
+ (16, 1024),
+ (32, 1024),
+ (64, 1024),
+ (128, 1024),
+ (256, 1024),
+ (512, 1024),
+ (1024, 1024),
+ (2 * 1024, 1024),
+ (4 * 1024, 1024),
+ (8 * 1024, 1024),
+ (16 * 1024, 1024),
+ (32 * 1024, 1024),
+ (64 * 1024, 1024),
+ (128 * 1024, 1024),
+ (256 * 1024, 1024),
+ (512 * 1024, 1024),
+ (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+ world_size = dist.get_world_size(group)
+ assert (sl % world_size) == 0
+ send_recv_sizes = [sl // world_size] * world_size
+
+ x = torch.randn((sl, hs)).cuda().half()
+
+ details = {
+ 'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2 bytes per fp16 element.
+ }
+
+ def benchmark():
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+ time, std = benchmark_util.benchmark_function(benchmark)
+
+ if dist.get_rank(group) == 0:
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+ assert dist.is_available()
+ group = dist.init_process_group(backend='nccl')
+ local_rank = dist.get_rank(group)
+ torch.cuda.set_device(local_rank)
+
+ for args in _ALL_TO_ALL_BENCHMARK:
+ benchmark_all_to_all(group, *args)
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/ops/binned_gather.py b/build/torch29-cxx11-xpu20252-x86_64-linux/ops/binned_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/ops/binned_gather.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_gather kernel.
+class BinnedGatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bins: torch.Tensor,
+ bin_size: int,
+ top_k: int,
+ ):
+ ctx.save_for_backward(indices, bins)
+ ctx.top_k = top_k
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, grad: torch.Tensor):
+ grad = grad.contiguous()
+ indices, bins = ctx.saved_tensors
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
+ return out, None, None, None, None
+
+
+binned_gather = BinnedGatherOp.apply
diff --git a/build/torch29-cxx11-xpu20252-x86_64-linux/ops/binned_scatter.py b/build/torch29-cxx11-xpu20252-x86_64-linux/ops/binned_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d
--- /dev/null
+++ b/build/torch29-cxx11-xpu20252-x86_64-linux/ops/binned_scatter.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for binned_scatter kernel.
+class BinnedScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ weights: t